jax.lax.dynamic_update_index_in_dim#

jax.lax.dynamic_update_index_in_dim(operand, update, index, axis)[原始碼]#

圍繞 dynamic_update_slice() 的便利包裝器,用於在單一 axis 中更新大小為 1 的切片。

參數:
  • operand (Array | np.ndarray) – 要切片的陣列。

  • update (ArrayLike) – 包含要寫入 operand 的新值的陣列。

  • index (ArrayLike) – 單一純量索引

  • axis (int) – 更新的軸。

返回:

更新後的陣列

返回類型:

Array

範例

>>> x = jnp.zeros(6)
>>> y = 1.0
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 0., 0., 0.], dtype=float32)
>>> y = jnp.array([1.0])
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 0., 0., 0.], dtype=float32)

如果指定的索引超出範圍,索引將被裁剪到有效範圍

>>> dynamic_update_index_in_dim(x, y, 10, axis=0)
Array([0., 0., 0., 0., 0., 1.], dtype=float32)

以下是二維動態索引更新的範例

>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones(4)
>>> dynamic_update_index_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
      [1., 1., 1., 1.],
      [0., 0., 0., 0.],
      [0., 0., 0., 0.]], dtype=float32)

請注意,update 中額外軸的形狀不需要與 operand 的相關維度相符

>>> y = jnp.ones((1, 3))
>>> dynamic_update_index_in_dim(x, y, 1, 0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)