jax.numpy.place#
- jax.numpy.place(arr, mask, vals, *, inplace=True)[原始碼]#
根據遮罩更新陣列元素。
numpy.place()
的 JAX 實作。numpy.place()
的語義是就地修改陣列,這對於 JAX 的不可變陣列來說是不可能的。JAX 版本會傳回輸入的修改副本,並新增inplace
參數,使用者必須將其設定為 `False`,以提醒此 API 差異。- 參數:
arr (ArrayLike) – 值將放入的陣列。
mask (ArrayLike) – 與
arr
大小相同的布林遮罩。vals (ArrayLike) – 要插入到
arr
中由遮罩指示位置的值。如果提供的值過多,將會被截斷。如果提供的值不足,將會被重複。inplace (bool) – 必須設定為 False,以指示輸入不會就地修改,而是傳回修改後的副本。
- 傳回:
設定了來自 vals 條目的遮罩值的
arr
副本。- 傳回類型:
參見
jax.numpy.put()
:將元素放入數值索引的陣列中。jax.numpy.ndarray.at()
:使用 NumPy 樣式索引的陣列更新
範例
>>> x = jnp.zeros((3, 5), dtype=int) >>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape) >>> mask Array([[ True, False, False, True, False], [False, True, False, False, True], [False, False, True, False, False]], dtype=bool)
放置純量值
>>> jnp.place(x, mask, 1, inplace=False) Array([[1, 0, 0, 1, 0], [0, 1, 0, 0, 1], [0, 0, 1, 0, 0]], dtype=int32)
在此案例中,
jnp.place
類似於遮罩陣列更新語法>>> x.at[mask].set(1) Array([[1, 0, 0, 1, 0], [0, 1, 0, 0, 1], [0, 0, 1, 0, 0]], dtype=int32)
當從陣列放置值時,
place
會有所不同。陣列會重複以填滿遮罩的條目>>> vals = jnp.array([1, 3, 5]) >>> jnp.place(x, mask, vals, inplace=False) Array([[1, 0, 0, 3, 0], [0, 5, 0, 0, 1], [0, 0, 3, 0, 0]], dtype=int32)