jax.lax.switch#

jax.lax.switch(index, branches, *operands, operand=<object object>)[source]#

根據 index 套用 branches 中的其中一個分支。

如果 index 超出範圍,則會將其箝制在範圍內。

具有以下 Python 的語意

def switch(index, branches, *operands):
  index = clamp(0, index, len(branches) - 1)
  return branches[index](*operands)

在內部,這會包裝 XLA 的 Conditional 運算子。但是,當使用 vmap() 轉換以對一批述詞進行操作時,cond 會轉換為 select()

參數:
  • index – 整數純量類型,指示要套用哪個分支函數。

  • branches (Sequence[Callable]) – 函數序列 (A -> B),根據 index 套用。所有分支都必須傳回相同的輸出結構。

  • operands – 運算元 (A),輸入到要套用的分支。

傳回值:

根據 index 選取的分支的 branch(*operands) 值 (B)。