jax.numpy.take#
- jax.numpy.take(a, indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)[原始碼]#
從陣列中取出元素。
numpy.take()
的 JAX 實作,以jax.lax.gather()
實作。mode
參數如下,JAX 的行為在超出邊界索引的情況下與 NumPy 不同;請參閱下方的mode
參數。- 參數::
a (ArrayLike) – 要从中取值的陣列。
indices (ArrayLike) – 從陣列中取值的整數索引 N 維陣列。
axis (int | None | None) – 沿著取值的軸。如果未指定,陣列將在套用索引之前展平。
mode (str | None | None) – 超出邊界索引模式,可以是
"fill"
或"clip"
。預設的mode="fill"
會為超出邊界的索引傳回無效值 (例如 NaN);fill_value
參數可控制此值。有關mode
選項的更多討論,請參閱jax.numpy.ndarray.at
。fill_value (StaticScalar | None | None) – 當 mode 為 ‘fill’ 時,針對超出邊界切片傳回的填充值。否則將忽略。對於非精確類型,預設為 NaN;對於帶符號類型,預設為最大負值;對於無符號類型,預設為最大正值;對於布林值,預設為 True。
unique_indices (bool) – 如果為 True,則實作會假設索引是唯一的,這可以在某些後端產生更有效率的執行。如果設定為 True 且索引不是唯一的,則輸出是未定義的。
indices_are_sorted (bool) – 如果為 True,則實作會假設索引是升序排序的,這可以在某些後端產生更有效率的執行。如果設定為 True 且索引未排序,則輸出是未定義的。
out (None | None)
- 返回::
從
a
提取的值的陣列。- 返回類型::
參見
jax.numpy.ndarray.at
:透過索引語法取值。
範例
>>> x = jnp.array([[1., 2., 3.], ... [4., 5., 6.]]) >>> indices = jnp.array([2, 0])
不傳遞軸會導致索引到展平的陣列中
>>> jnp.take(x, indices) Array([3., 1.], dtype=float32) >>> x.ravel()[indices] # equivalent indexing syntax Array([3., 1.], dtype=float32)
傳遞軸會導致將索引應用於沿軸的每個子陣列
>>> jnp.take(x, indices, axis=1) Array([[3., 1.], [6., 4.]], dtype=float32) >>> x[:, indices] # equivalent indexing syntax Array([[3., 1.], [6., 4.]], dtype=float32)
超出邊界的索引會以無效值填充。對於浮點輸入,這是 NaN
>>> jnp.take(x, indices, axis=0) Array([[nan, nan, nan], [ 1., 2., 3.]], dtype=float32) >>> x.at[indices].get(mode='fill', fill_value=jnp.nan) # equivalent indexing syntax Array([[nan, nan, nan], [ 1., 2., 3.]], dtype=float32)
可以使用
mode
參數調整此預設的超出邊界行為,例如,我們可以改為裁剪到最後一個有效值>>> jnp.take(x, indices, axis=0, mode='clip') Array([[4., 5., 6.], [1., 2., 3.]], dtype=float32) >>> x.at[indices].get(mode='clip') # equivalent indexing syntax Array([[4., 5., 6.], [1., 2., 3.]], dtype=float32)