jax.numpy.take_along_axis#

jax.numpy.take_along_axis(arr, indices, axis, mode=None, fill_value=None)[原始碼]#

從陣列中取出元素。

numpy.take_along_axis() 的 JAX 實作,以 jax.lax.gather() 實作。JAX 的行為在超出範圍的索引情況下與 NumPy 不同;請參閱下方的 mode 參數。

參數:
  • a – 要从中取值的陣列。

  • indices (ArrayLike) – 整數索引陣列。如果 axisNone,則必須是一維的。如果 axis 不為 None,則必須具有 a.ndim == indices.ndim,且 a 必須與 indicesaxis 以外的維度上廣播相容。

  • axis (int | None) – 要沿其取值的軸。如果未指定,陣列將在套用索引之前展平。

  • mode (str | lax.GatherScatterMode | None) – 超出範圍的索引模式,可以是 "fill""clip"。預設 mode="fill" 會為超出範圍的索引傳回無效值(例如 NaN)。如需 mode 選項的更多討論,請參閱 jax.numpy.ndarray.at

  • arr (ArrayLike)

  • fill_value (StaticScalar | None)

傳回:

a 提取的值的陣列。

傳回類型:

陣列

參見

範例

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 6.]])
>>> indices = jnp.array([[0, 2],
...                      [1, 0]])
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1., 3.],
       [5., 4.]], dtype=float32)
>>> x[jnp.arange(2)[:, None], indices]  # equivalent via indexing syntax
Array([[1., 3.],
       [5., 4.]], dtype=float32)

超出範圍的索引會以無效值填滿。對於浮點輸入,這是 NaN

>>> indices = jnp.array([[1, 0, 2]])
>>> jnp.take_along_axis(x, indices, axis=0)
Array([[ 4.,  2., nan]], dtype=float32)
>>> x.at[indices, jnp.arange(3)].get(
...     mode='fill', fill_value=jnp.nan)  # equivalent via indexing syntax
Array([[ 4.,  2., nan]], dtype=float32)

take_along_axis 對於從多維 argsort 和 arg 縮減中提取值很有幫助。例如,在這裡我們沿著軸計算 argsort() 索引,並使用 take_along_axis 建構排序後的陣列

>>> x = jnp.array([[5, 3, 4],
...                [2, 7, 6]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 2, 0],
       [0, 2, 1]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[3, 4, 5],
       [2, 6, 7]], dtype=int32)

同樣地,我們可以將 argmin()keepdims=True 一起使用,並使用 take_along_axis 提取最小值

>>> idx = jnp.argmin(x, axis=1, keepdims=True)
>>> idx
Array([[1],
       [0]], dtype=int32)
>>> jnp.take_along_axis(x, idx, axis=1)
Array([[3],
       [2]], dtype=int32)