jax.lax.dynamic_update_index_in_dim#
- jax.lax.dynamic_update_index_in_dim(operand, update, index, axis)[原始碼]#
圍繞
dynamic_update_slice()
的便利包裝器,用於在單一axis
中更新大小為 1 的切片。- 參數:
- 返回:
更新後的陣列
- 返回類型:
範例
>>> x = jnp.zeros(6) >>> y = 1.0 >>> dynamic_update_index_in_dim(x, y, 2, axis=0) Array([0., 0., 1., 0., 0., 0.], dtype=float32)
>>> y = jnp.array([1.0]) >>> dynamic_update_index_in_dim(x, y, 2, axis=0) Array([0., 0., 1., 0., 0., 0.], dtype=float32)
如果指定的索引超出範圍,索引將被裁剪到有效範圍
>>> dynamic_update_index_in_dim(x, y, 10, axis=0) Array([0., 0., 0., 0., 0., 1.], dtype=float32)
以下是二維動態索引更新的範例
>>> x = jnp.zeros((4, 4)) >>> y = jnp.ones(4) >>> dynamic_update_index_in_dim(x, y, 1, axis=0) Array([[0., 0., 0., 0.], [1., 1., 1., 1.], [0., 0., 0., 0.], [0., 0., 0., 0.]], dtype=float32)
請注意,
update
中額外軸的形狀不需要與operand
的相關維度相符>>> y = jnp.ones((1, 3)) >>> dynamic_update_index_in_dim(x, y, 1, 0) Array([[0., 0., 0., 0.], [1., 1., 1., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], dtype=float32)