jax.numpy.choose#

jax.numpy.choose(a, choices, out=None, mode='raise')[原始碼]#

透過堆疊選擇陣列的切片來建構陣列。

JAX 實作的 numpy.choose()

此函式的語意可能令人困惑,但在最簡單的情況下,當 a 是一維陣列,choices 是二維陣列,且 a 的所有條目都在界內時 (即 0 <= a_i < len(choices)),則此函式等效於以下程式碼

def choose(a, choices):
  return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)])

在更一般的情況下,a 可以具有任意數量的維度,而 choices 可以是廣播相容陣列的任意序列。在這種情況下,同樣針對界內索引,邏輯等效於

def choose(a, choices):
  a, *choices = jnp.broadcast_arrays(a, *choices)
  choices = jnp.array(choices)
  return jnp.array([choices[a[idx], *idx] for idx in np.ndindex(a.shape)])

唯一的額外複雜性來自 mode 引數,它控制 a 中超出邊界索引的行為,如下所述。

參數:
  • a (ArrayLike) – 整數索引的 N 維陣列。

  • choices (Array | np.ndarray | Sequence[ArrayLike]) – 陣列或陣列序列。序列中的所有陣列都必須與 a 相互廣播相容。

  • out (None | None) – JAX 未使用

  • mode (str) – 指定超出邊界索引模式;選項為 'raise' (預設)、'wrap''clip' 之一。請注意,'raise' 的預設模式與 JAX 轉換不相容。

傳回:

一個陣列,其中包含來自 choicesa 指定索引處堆疊的切片。結果的形狀為 broadcast_shapes(a.shape, *(c.shape for c in choices))

傳回類型:

Array

參見

範例

這是 1D 索引陣列與 2D 選擇陣列的最簡單情況,在這種情況下,它會從每一列中選擇索引值

>>> choices = jnp.array([[ 1,  2,  3,  4],
...                      [ 5,  6,  7,  8],
...                      [ 9, 10, 11, 12]])
>>> a = jnp.array([2, 0, 1, 0])
>>> jnp.choose(a, choices)
Array([9, 2, 7, 4], dtype=int32)

mode 引數指定如何處理超出邊界索引;選項為 wrapclip

>>> a2 = jnp.array([2, 0, 1, 4])  # last index out-of-bound
>>> jnp.choose(a2, choices, mode='clip')
Array([ 9,  2,  7, 12], dtype=int32)
>>> jnp.choose(a2, choices, mode='wrap')
Array([9, 2, 7, 8], dtype=int32)

在更一般的情況下,choices 可以是具有任何廣播相容形狀的類陣列物件序列。

>>> choice_1 = jnp.array([1, 2, 3, 4])
>>> choice_2 = 99
>>> choice_3 = jnp.array([[10],
...                       [20],
...                       [30]])
>>> a = jnp.array([[0, 1, 2, 0],
...                [1, 2, 0, 1],
...                [2, 0, 1, 2]])
>>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap')
Array([[ 1, 99, 10,  4],
       [99, 20,  3, 99],
       [30,  2, 99, 30]], dtype=int32)