jax.lax.cond#
- jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>)[原始碼]#
有條件地套用
true_fun
或false_fun
。包裝 XLA 的 Conditional 運算子。
如果提供的引數型別正確,
cond()
具有與此 Python 實作等效的語意,其中pred
必須是純量型別def cond(pred, true_fun, false_fun, *operands): if pred: return true_fun(*operands) else: return false_fun(*operands)
與
jax.lax.select()
相反,使用cond
表示只會執行兩個分支之一(取決於編譯器重寫和最佳化)。但是,當使用vmap()
轉換以對一批述詞進行運算時,cond
會轉換為select()
。- 參數:
pred – 布林純量型別,指示要套用哪個分支函式。
true_fun (Callable) – 函式 (A -> B),如果
pred
為 True,則套用此函式。false_fun (Callable) – 函式 (A -> B),如果
pred
為 False,則套用此函式。operands – 運算元 (A) 輸入到任一分支,取決於
pred
。型別可以是純量、陣列或任何 pytree(巢狀 Python tuple/list/dict)。
- 傳回值:
值 (B) 為
true_fun(*operands)
或false_fun(*operands)
其中之一,取決於pred
的值。型別可以是純量、陣列或任何 pytree(巢狀 Python tuple/list/dict)。