jax.lax.select_n#

jax.lax.select_n(which, *cases)[原始碼]#

從多個情況中選取陣列值。

概括了 XLA 的 Select 運算符。與 XLA 的版本不同,此運算符是可變參數的,並且可以使用整數 pred 從多種情況中進行選擇。

參數:
  • which (ArrayLike) – 決定應返回哪種情況。必須是包含布林值或整數值的陣列。可以是純量或具有與 cases 相符的形狀。對於每個陣列元素,which 的值決定了採用 cases 中的哪一種。which 必須在範圍 [0 .. len(cases)) 內;對於該範圍之外的值,行為是實作定義的。

  • *cases (ArrayLike) – 陣列情況的非空列表。所有情況都必須具有相等的資料類型和相等的形狀。

返回值:

一個陣列,其形狀和資料類型與情況相同,其值根據 which 選擇。

返回類型:

Array