jax.numpy.take#

jax.numpy.take(a, indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)[原始碼]#

從陣列中取出元素。

numpy.take() 的 JAX 實作,以 jax.lax.gather() 實作。mode 參數如下,JAX 的行為在超出邊界索引的情況下與 NumPy 不同;請參閱下方的 mode 參數。

參數::
  • a (ArrayLike) – 要从中取值的陣列。

  • indices (ArrayLike) – 從陣列中取值的整數索引 N 維陣列。

  • axis (int | None | None) – 沿著取值的軸。如果未指定,陣列將在套用索引之前展平。

  • mode (str | None | None) – 超出邊界索引模式,可以是 "fill""clip"。預設的 mode="fill" 會為超出邊界的索引傳回無效值 (例如 NaN);fill_value 參數可控制此值。有關 mode 選項的更多討論,請參閱 jax.numpy.ndarray.at

  • fill_value (StaticScalar | None | None) – 當 mode 為 ‘fill’ 時,針對超出邊界切片傳回的填充值。否則將忽略。對於非精確類型,預設為 NaN;對於帶符號類型,預設為最大負值;對於無符號類型,預設為最大正值;對於布林值,預設為 True。

  • unique_indices (bool) – 如果為 True,則實作會假設索引是唯一的,這可以在某些後端產生更有效率的執行。如果設定為 True 且索引不是唯一的,則輸出是未定義的。

  • indices_are_sorted (bool) – 如果為 True,則實作會假設索引是升序排序的,這可以在某些後端產生更有效率的執行。如果設定為 True 且索引未排序,則輸出是未定義的。

  • out (None | None)

返回::

a 提取的值的陣列。

返回類型::

陣列

參見

範例

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 6.]])
>>> indices = jnp.array([2, 0])

不傳遞軸會導致索引到展平的陣列中

>>> jnp.take(x, indices)
Array([3., 1.], dtype=float32)
>>> x.ravel()[indices]  # equivalent indexing syntax
Array([3., 1.], dtype=float32)

傳遞軸會導致將索引應用於沿軸的每個子陣列

>>> jnp.take(x, indices, axis=1)
Array([[3., 1.],
       [6., 4.]], dtype=float32)
>>> x[:, indices]  # equivalent indexing syntax
Array([[3., 1.],
       [6., 4.]], dtype=float32)

超出邊界的索引會以無效值填充。對於浮點輸入,這是 NaN

>>> jnp.take(x, indices, axis=0)
Array([[nan, nan, nan],
       [ 1.,  2.,  3.]], dtype=float32)
>>> x.at[indices].get(mode='fill', fill_value=jnp.nan)  # equivalent indexing syntax
Array([[nan, nan, nan],
       [ 1.,  2.,  3.]], dtype=float32)

可以使用 mode 參數調整此預設的超出邊界行為,例如,我們可以改為裁剪到最後一個有效值

>>> jnp.take(x, indices, axis=0, mode='clip')
Array([[4., 5., 6.],
       [1., 2., 3.]], dtype=float32)
>>> x.at[indices].get(mode='clip')  # equivalent indexing syntax
Array([[4., 5., 6.],
       [1., 2., 3.]], dtype=float32)