jax.numpy.put_along_axis#
- jax.numpy.put_along_axis(arr, indices, values, axis, inplace=True, *, mode=None)[原始碼]#
透過比對 1 維索引和資料切片,將值放入目標陣列。
JAX 實作的
numpy.put_along_axis()
。numpy.put_along_axis()
的語意是就地修改陣列,這對於 JAX 的不可變陣列來說是不可能的。JAX 版本會回傳輸入的修改副本,並新增inplace
參數,使用者必須將其設定為 `False`,以提醒此 API 差異。- 參數:
arr (ArrayLike) – 值將放入的陣列。
indices (ArrayLike) – 要放入值的索引陣列。
values (ArrayLike) – 要放入陣列的值陣列。
axis (int | None) – 沿著哪個軸放置值。如果未指定,陣列將在套用索引之前展平。
inplace (bool) – 必須設定為 False,以表明輸入不會就地修改,而是回傳修改後的副本。
mode (str | None) – 越界索引模式。如需關於
mode
選項的更多討論,請參閱jax.numpy.ndarray.at
。
- 回傳:
更新指定條目後的
a
副本。- 回傳類型:
另請參閱
jax.numpy.put()
:在給定索引處將元素放入陣列。jax.numpy.place()
:透過布林遮罩將元素放入陣列。jax.numpy.ndarray.at()
:使用 NumPy 樣式索引的陣列更新。jax.numpy.take()
:在給定索引處從陣列中提取值。jax.numpy.take_along_axis()
:沿軸從陣列中提取值。
範例
>>> from jax import numpy as jnp >>> a = jnp.array([[10, 30, 20], [60, 40, 50]]) >>> i = jnp.argmax(a, axis=1, keepdims=True) >>> print(i) [[1] [0]] >>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False) >>> print(b) [[10 99 20] [99 40 50]]