jax.lax.select_n#
- jax.lax.select_n(which, *cases)[原始碼]#
從多個情況中選取陣列值。
概括了 XLA 的 Select 運算符。與 XLA 的版本不同,此運算符是可變參數的,並且可以使用整數 pred 從多種情況中進行選擇。
- 參數:
which (ArrayLike) – 決定應返回哪種情況。必須是包含布林值或整數值的陣列。可以是純量或具有與
cases
相符的形狀。對於每個陣列元素,which
的值決定了採用cases
中的哪一種。which
必須在範圍[0 .. len(cases))
內;對於該範圍之外的值,行為是實作定義的。*cases (ArrayLike) – 陣列情況的非空列表。所有情況都必須具有相等的資料類型和相等的形狀。
- 返回值:
一個陣列,其形狀和資料類型與情況相同,其值根據
which
選擇。- 返回類型: