jax.numpy.put_along_axis#

jax.numpy.put_along_axis(arr, indices, values, axis, inplace=True, *, mode=None)[原始碼]#

透過比對 1 維索引和資料切片,將值放入目標陣列。

JAX 實作的 numpy.put_along_axis()

numpy.put_along_axis() 的語意是就地修改陣列,這對於 JAX 的不可變陣列來說是不可能的。JAX 版本會回傳輸入的修改副本,並新增 inplace 參數,使用者必須將其設定為 `False`,以提醒此 API 差異。

參數:
  • arr (ArrayLike) – 值將放入的陣列。

  • indices (ArrayLike) – 要放入值的索引陣列。

  • values (ArrayLike) – 要放入陣列的值陣列。

  • axis (int | None) – 沿著哪個軸放置值。如果未指定,陣列將在套用索引之前展平。

  • inplace (bool) – 必須設定為 False,以表明輸入不會就地修改,而是回傳修改後的副本。

  • mode (str | None) – 越界索引模式。如需關於 mode 選項的更多討論,請參閱 jax.numpy.ndarray.at

回傳:

更新指定條目後的 a 副本。

回傳類型:

陣列

另請參閱

範例

>>> from jax import numpy as jnp
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
>>> i = jnp.argmax(a, axis=1, keepdims=True)
>>> print(i)
[[1]
 [0]]
>>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
>>> print(b)
[[10 99 20]
 [99 40 50]]