jax.lax.dynamic_update_slice#

jax.lax.dynamic_update_slice(operand, update, start_indices)[source]#

包裝 XLA 的 DynamicUpdateSlice 運算子。

參數:
  • operand (Array | np.ndarray) – 要切片的陣列。

  • update (ArrayLike) – 包含要寫入 operand 的新值的陣列。

  • start_indices (Array | Sequence[ArrayLike]) – 純量索引的列表,每個維度一個。

回傳:

包含切片的陣列。

回傳型別:

Array

範例

以下是如何更新一維切片更新的範例

>>> x = jnp.zeros(6)
>>> y = jnp.ones(3)
>>> dynamic_update_slice(x, y, (2,))
Array([0., 0., 1., 1., 1., 0.], dtype=float32)

如果更新切片太大而無法放入陣列,則會調整起始索引以使其符合

>>> dynamic_update_slice(x, y, (3,))
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
>>> dynamic_update_slice(x, y, (5,))
Array([0., 0., 0., 1., 1., 1.], dtype=float32)

以下是如何更新二維切片更新的範例

>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones((2, 2))
>>> dynamic_update_slice(x, y, (1, 2))
Array([[0., 0., 0., 0.],
       [0., 0., 1., 1.],
       [0., 0., 1., 1.],
       [0., 0., 0., 0.]], dtype=float32)

另請參閱