jax.lax.switch#
- jax.lax.switch(index, branches, *operands, operand=<object object>)[source]#
根據
index
套用branches
中的其中一個分支。如果
index
超出範圍,則會將其箝制在範圍內。具有以下 Python 的語意
def switch(index, branches, *operands): index = clamp(0, index, len(branches) - 1) return branches[index](*operands)
在內部,這會包裝 XLA 的 Conditional 運算子。但是,當使用
vmap()
轉換以對一批述詞進行操作時,cond
會轉換為select()
。- 參數:
index – 整數純量類型,指示要套用哪個分支函數。
branches (Sequence[Callable]) – 函數序列 (A -> B),根據
index
套用。所有分支都必須傳回相同的輸出結構。operands – 運算元 (A),輸入到要套用的分支。
- 傳回值:
根據
index
選取的分支的branch(*operands)
值 (B)。