jax.numpy.put#
- jax.numpy.put(a, ind, v, mode=None, *, inplace=True)[原始碼]#
將元素放入給定索引的陣列中。
JAX 版本的
numpy.put()
實作。numpy.put()
的語意是就地修改陣列,這對於 JAX 的不可變陣列來說是不可能的。JAX 版本會傳回輸入的修改副本,並加入inplace
參數,使用者必須將其設定為 False,以提醒此 API 差異。- 參數:
- 傳回:
a
的副本,其中指定的條目已更新。- 傳回型別:
另請參閱
jax.numpy.place()
:透過布林遮罩將元素放入陣列中。jax.numpy.ndarray.at()
:使用 NumPy 風格索引的陣列更新。jax.numpy.take()
:從給定索引的陣列中提取值。
範例
>>> x = jnp.zeros(5, dtype=int) >>> indices = jnp.array([0, 2, 4]) >>> values = jnp.array([10, 20, 30]) >>> jnp.put(x, indices, values, inplace=False) Array([10, 0, 20, 0, 30], dtype=int32)
這等效於以下的
jax.numpy.ndarray.at
索引語法>>> x.at[indices].set(values) Array([10, 0, 20, 0, 30], dtype=int32)
有兩種模式可以處理超出邊界的索引。預設情況下,它們會被裁剪
>>> indices = jnp.array([0, 2, 6]) >>> jnp.put(x, indices, values, inplace=False, mode='clip') Array([10, 0, 20, 0, 30], dtype=int32)
或者,它們可以被包裝到陣列的開頭
>>> jnp.put(x, indices, values, inplace=False, mode='wrap') Array([10, 30, 20, 0, 0], dtype=int32)
對於 N 維輸入,索引指向扁平化陣列
>>> x = jnp.zeros((3, 5), dtype=int) >>> indices = jnp.array([0, 7, 14]) >>> jnp.put(x, indices, values, inplace=False) Array([[10, 0, 0, 0, 0], [ 0, 0, 20, 0, 0], [ 0, 0, 0, 0, 30]], dtype=int32)