jax.numpy.take_along_axis#
- jax.numpy.take_along_axis(arr, indices, axis, mode=None, fill_value=None)[原始碼]#
從陣列中取出元素。
numpy.take_along_axis()
的 JAX 實作,以jax.lax.gather()
實作。JAX 的行為在超出範圍的索引情況下與 NumPy 不同;請參閱下方的mode
參數。- 參數:
a – 要从中取值的陣列。
indices (ArrayLike) – 整數索引陣列。如果
axis
為None
,則必須是一維的。如果axis
不為 None,則必須具有a.ndim == indices.ndim
,且a
必須與indices
在axis
以外的維度上廣播相容。axis (int | None) – 要沿其取值的軸。如果未指定,陣列將在套用索引之前展平。
mode (str | lax.GatherScatterMode | None) – 超出範圍的索引模式,可以是
"fill"
或"clip"
。預設mode="fill"
會為超出範圍的索引傳回無效值(例如 NaN)。如需mode
選項的更多討論,請參閱jax.numpy.ndarray.at
。arr (ArrayLike)
fill_value (StaticScalar | None)
- 傳回:
從
a
提取的值的陣列。- 傳回類型:
參見
jax.numpy.ndarray.at
:透過索引語法取值。jax.numpy.take()
:沿著每個軸切片取相同的索引。
範例
>>> x = jnp.array([[1., 2., 3.], ... [4., 5., 6.]]) >>> indices = jnp.array([[0, 2], ... [1, 0]]) >>> jnp.take_along_axis(x, indices, axis=1) Array([[1., 3.], [5., 4.]], dtype=float32) >>> x[jnp.arange(2)[:, None], indices] # equivalent via indexing syntax Array([[1., 3.], [5., 4.]], dtype=float32)
超出範圍的索引會以無效值填滿。對於浮點輸入,這是 NaN
>>> indices = jnp.array([[1, 0, 2]]) >>> jnp.take_along_axis(x, indices, axis=0) Array([[ 4., 2., nan]], dtype=float32) >>> x.at[indices, jnp.arange(3)].get( ... mode='fill', fill_value=jnp.nan) # equivalent via indexing syntax Array([[ 4., 2., nan]], dtype=float32)
take_along_axis
對於從多維 argsort 和 arg 縮減中提取值很有幫助。例如,在這裡我們沿著軸計算argsort()
索引,並使用take_along_axis
建構排序後的陣列>>> x = jnp.array([[5, 3, 4], ... [2, 7, 6]]) >>> indices = jnp.argsort(x, axis=1) >>> indices Array([[1, 2, 0], [0, 2, 1]], dtype=int32) >>> jnp.take_along_axis(x, indices, axis=1) Array([[3, 4, 5], [2, 6, 7]], dtype=int32)
同樣地,我們可以將
argmin()
與keepdims=True
一起使用,並使用take_along_axis
提取最小值>>> idx = jnp.argmin(x, axis=1, keepdims=True) >>> idx Array([[1], [0]], dtype=int32) >>> jnp.take_along_axis(x, idx, axis=1) Array([[3], [2]], dtype=int32)