jax.numpy.put#

jax.numpy.put(a, ind, v, mode=None, *, inplace=True)[原始碼]#

將元素放入給定索引的陣列中。

JAX 版本的 numpy.put() 實作。

numpy.put() 的語意是就地修改陣列,這對於 JAX 的不可變陣列來說是不可能的。JAX 版本會傳回輸入的修改副本,並加入 inplace 參數,使用者必須將其設定為 False,以提醒此 API 差異。

參數:
  • a (ArrayLike) – 將值放入的陣列。

  • ind (ArrayLike) – 索引陣列,指向要放入值的扁平化陣列。

  • v (ArrayLike) – 要放入陣列的值陣列。

  • mode (str | None | None) –

    字串,指定如何處理超出邊界的索引。支援的值

    • "clip" (預設):將超出邊界的索引裁剪到最後一個索引。

    • "wrap":將超出邊界的索引包裝到陣列的開頭。

  • inplace (bool) – 必須設定為 False,以表明輸入不會就地修改,而是傳回修改後的副本。

傳回:

a 的副本,其中指定的條目已更新。

傳回型別:

Array

另請參閱

範例

>>> x = jnp.zeros(5, dtype=int)
>>> indices = jnp.array([0, 2, 4])
>>> values = jnp.array([10, 20, 30])
>>> jnp.put(x, indices, values, inplace=False)
Array([10,  0, 20,  0, 30], dtype=int32)

這等效於以下的 jax.numpy.ndarray.at 索引語法

>>> x.at[indices].set(values)
Array([10,  0, 20,  0, 30], dtype=int32)

有兩種模式可以處理超出邊界的索引。預設情況下,它們會被裁剪

>>> indices = jnp.array([0, 2, 6])
>>> jnp.put(x, indices, values, inplace=False, mode='clip')
Array([10,  0, 20,  0, 30], dtype=int32)

或者,它們可以被包裝到陣列的開頭

>>> jnp.put(x, indices, values, inplace=False, mode='wrap')
Array([10,  30, 20,  0, 0], dtype=int32)

對於 N 維輸入,索引指向扁平化陣列

>>> x = jnp.zeros((3, 5), dtype=int)
>>> indices = jnp.array([0, 7, 14])
>>> jnp.put(x, indices, values, inplace=False)
Array([[10,  0,  0,  0,  0],
       [ 0,  0, 20,  0,  0],
       [ 0,  0,  0,  0, 30]], dtype=int32)