jax.numpy.select#

jax.numpy.select(condlist, choicelist, default=0)[原始碼]#

根據一系列條件選擇值。

JAX 實作的 numpy.select(),以 jax.lax.select_n() 實作

參數:
  • condlist (Sequence[ArrayLike]) – 類陣列條件的序列。所有條目必須彼此廣播相容。

  • choicelist (Sequence[ArrayLike]) – 要選擇的類陣列值的序列。必須與 condlist 具有相同的長度,且所有條目必須與 condlist 的條目廣播相容。

  • default (ArrayLike) – 當每個條件都為 False 時要傳回的值 (預設值:0)。

傳回:

choicelist 中選取的值的陣列,對應於每個位置中 condlist 中的第一個 True 條目。

傳回型別:

Array

參見

範例

>>> condlist = [
...    jnp.array([False, True, False, False]),
...    jnp.array([True, False, False, False]),
...    jnp.array([False, True, True, False]),
... ]
>>> choicelist = [
...    jnp.array([1, 2, 3, 4]),
...    jnp.array([10, 20, 30, 40]),
...    jnp.array([100, 200, 300, 400]),
... ]
>>> jnp.select(condlist, choicelist, default=0)
Array([ 10,   2, 300,   0], dtype=int32)

這在邏輯上等同於以下巢狀 where 陳述式

>>> default = 0
>>> jnp.where(condlist[0],
...   choicelist[0],
...   jnp.where(condlist[1],
...     choicelist[1],
...     jnp.where(condlist[2],
...       choicelist[2],
...       default)))
Array([ 10,   2, 300,   0], dtype=int32)

但是,為了效率,它以 jax.lax.select_n() 實作。