jax.numpy.choose#
- jax.numpy.choose(a, choices, out=None, mode='raise')[原始碼]#
透過堆疊選擇陣列的切片來建構陣列。
JAX 實作的
numpy.choose()
。此函式的語意可能令人困惑,但在最簡單的情況下,當
a
是一維陣列,choices
是二維陣列,且a
的所有條目都在界內時 (即0 <= a_i < len(choices)
),則此函式等效於以下程式碼def choose(a, choices): return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])
在更一般的情況下,
a
可以具有任意數量的維度,而choices
可以是廣播相容陣列的任意序列。在這種情況下,同樣針對界內索引,邏輯等效於def choose(a, choices): a, *choices = jnp.broadcast_arrays(a, *choices) choices = jnp.array(choices) return jnp.array([choices[a[idx], *idx] for idx in np.ndindex(a.shape)])
唯一的額外複雜性來自
mode
引數,它控制a
中超出邊界索引的行為,如下所述。- 參數:
- 傳回:
一個陣列,其中包含來自
choices
在a
指定索引處堆疊的切片。結果的形狀為broadcast_shapes(a.shape, *(c.shape for c in choices))
。- 傳回類型:
參見
jax.lax.switch()
:根據索引在 N 個函式之間選擇。
範例
這是 1D 索引陣列與 2D 選擇陣列的最簡單情況,在這種情況下,它會從每一列中選擇索引值
>>> choices = jnp.array([[ 1, 2, 3, 4], ... [ 5, 6, 7, 8], ... [ 9, 10, 11, 12]]) >>> a = jnp.array([2, 0, 1, 0]) >>> jnp.choose(a, choices) Array([9, 2, 7, 4], dtype=int32)
mode
引數指定如何處理超出邊界索引;選項為wrap
或clip
>>> a2 = jnp.array([2, 0, 1, 4]) # last index out-of-bound >>> jnp.choose(a2, choices, mode='clip') Array([ 9, 2, 7, 12], dtype=int32) >>> jnp.choose(a2, choices, mode='wrap') Array([9, 2, 7, 8], dtype=int32)
在更一般的情況下,
choices
可以是具有任何廣播相容形狀的類陣列物件序列。>>> choice_1 = jnp.array([1, 2, 3, 4]) >>> choice_2 = 99 >>> choice_3 = jnp.array([[10], ... [20], ... [30]]) >>> a = jnp.array([[0, 1, 2, 0], ... [1, 2, 0, 1], ... [2, 0, 1, 2]]) >>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap') Array([[ 1, 99, 10, 4], [99, 20, 3, 99], [30, 2, 99, 30]], dtype=int32)