jax.numpy.select#
- jax.numpy.select(condlist, choicelist, default=0)[原始碼]#
根據一系列條件選擇值。
JAX 實作的
numpy.select()
,以jax.lax.select_n()
實作- 參數:
condlist (Sequence[ArrayLike]) – 類陣列條件的序列。所有條目必須彼此廣播相容。
choicelist (Sequence[ArrayLike]) – 要選擇的類陣列值的序列。必須與
condlist
具有相同的長度,且所有條目必須與condlist
的條目廣播相容。default (ArrayLike) – 當每個條件都為 False 時要傳回的值 (預設值:0)。
- 傳回:
從
choicelist
中選取的值的陣列,對應於每個位置中condlist
中的第一個True
條目。- 傳回型別:
參見
jax.numpy.where()
:根據單一條件在兩個值之間選擇。jax.lax.select_n()
:根據索引在 N 個值之間選擇。
範例
>>> condlist = [ ... jnp.array([False, True, False, False]), ... jnp.array([True, False, False, False]), ... jnp.array([False, True, True, False]), ... ] >>> choicelist = [ ... jnp.array([1, 2, 3, 4]), ... jnp.array([10, 20, 30, 40]), ... jnp.array([100, 200, 300, 400]), ... ] >>> jnp.select(condlist, choicelist, default=0) Array([ 10, 2, 300, 0], dtype=int32)
這在邏輯上等同於以下巢狀
where
陳述式>>> default = 0 >>> jnp.where(condlist[0], ... choicelist[0], ... jnp.where(condlist[1], ... choicelist[1], ... jnp.where(condlist[2], ... choicelist[2], ... default))) Array([ 10, 2, 300, 0], dtype=int32)
但是,為了效率,它以
jax.lax.select_n()
實作。