jax.lax.select#

jax.lax.select(pred, on_true, on_false)[原始碼]#

根據布林述詞在兩個分支之間選擇。

包裝了 XLA 的 Select 運算子。

一般而言,select() 會導致兩個分支都進行評估,儘管編譯器可能會在可能的情況下省略計算。對於類似的函式,通常只評估單個分支,請參閱 cond()

參數:
  • pred (ArrayLike) – 布林陣列

  • on_true (ArrayLike) – 陣列,包含在 pred 為 True 時返回的條目。必須具有與 pred 相同的形狀,以及與 on_false 相同的形狀和 dtype。

  • on_false (ArrayLike) – 陣列,包含在 pred 為 False 時返回的條目。必須具有與 pred 相同的形狀,以及與 on_true 相同的形狀和 dtype。

返回:

on_trueon_false 具有相同形狀和 dtype 的陣列。

返回類型:

result