jax.numpy.place#

jax.numpy.place(arr, mask, vals, *, inplace=True)[原始碼]#

根據遮罩更新陣列元素。

numpy.place() 的 JAX 實作。

numpy.place() 的語義是就地修改陣列,這對於 JAX 的不可變陣列來說是不可能的。JAX 版本會傳回輸入的修改副本,並新增 inplace 參數,使用者必須將其設定為 `False`,以提醒此 API 差異。

參數:
  • arr (ArrayLike) – 值將放入的陣列。

  • mask (ArrayLike) – 與 arr 大小相同的布林遮罩。

  • vals (ArrayLike) – 要插入到 arr 中由遮罩指示位置的值。如果提供的值過多,將會被截斷。如果提供的值不足,將會被重複。

  • inplace (bool) – 必須設定為 False,以指示輸入不會就地修改,而是傳回修改後的副本。

傳回:

設定了來自 vals 條目的遮罩值的 arr 副本。

傳回類型:

Array

參見

範例

>>> x = jnp.zeros((3, 5), dtype=int)
>>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
>>> mask
Array([[ True, False, False,  True, False],
       [False,  True, False, False,  True],
       [False, False,  True, False, False]], dtype=bool)

放置純量值

>>> jnp.place(x, mask, 1, inplace=False)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)

在此案例中,jnp.place 類似於遮罩陣列更新語法

>>> x.at[mask].set(1)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)

當從陣列放置值時,place 會有所不同。陣列會重複以填滿遮罩的條目

>>> vals = jnp.array([1, 3, 5])
>>> jnp.place(x, mask, vals, inplace=False)
Array([[1, 0, 0, 3, 0],
       [0, 5, 0, 0, 1],
       [0, 0, 3, 0, 0]], dtype=int32)