jax.lax.cond#

jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>)[原始碼]#

有條件地套用 true_funfalse_fun

包裝 XLA 的 Conditional 運算子。

如果提供的引數型別正確,cond() 具有與此 Python 實作等效的語意,其中 pred 必須是純量型別

def cond(pred, true_fun, false_fun, *operands):
  if pred:
    return true_fun(*operands)
  else:
    return false_fun(*operands)

jax.lax.select() 相反,使用 cond 表示只會執行兩個分支之一(取決於編譯器重寫和最佳化)。但是,當使用 vmap() 轉換以對一批述詞進行運算時,cond 會轉換為 select()

參數:
  • pred – 布林純量型別,指示要套用哪個分支函式。

  • true_fun (Callable) – 函式 (A -> B),如果 pred 為 True,則套用此函式。

  • false_fun (Callable) – 函式 (A -> B),如果 pred 為 False,則套用此函式。

  • operands – 運算元 (A) 輸入到任一分支,取決於 pred。型別可以是純量、陣列或任何 pytree(巢狀 Python tuple/list/dict)。

傳回值:

值 (B) 為 true_fun(*operands)false_fun(*operands) 其中之一,取決於 pred 的值。型別可以是純量、陣列或任何 pytree(巢狀 Python tuple/list/dict)。