jax.lax.select#
- jax.lax.select(pred, on_true, on_false)[原始碼]#
根據布林述詞在兩個分支之間選擇。
包裝了 XLA 的 Select 運算子。
一般而言,
select()
會導致兩個分支都進行評估,儘管編譯器可能會在可能的情況下省略計算。對於類似的函式,通常只評估單個分支,請參閱cond()
。- 參數:
pred (ArrayLike) – 布林陣列
on_true (ArrayLike) – 陣列,包含在
pred
為 True 時返回的條目。必須具有與pred
相同的形狀,以及與on_false
相同的形狀和 dtype。on_false (ArrayLike) – 陣列,包含在
pred
為 False 時返回的條目。必須具有與pred
相同的形狀,以及與on_true
相同的形狀和 dtype。
- 返回:
與
on_true
和on_false
具有相同形狀和 dtype 的陣列。- 返回類型:
result