jax.lax.dynamic_slice_in_dim#
- jax.lax.dynamic_slice_in_dim(operand, start_index, slice_size, axis=0)[原始碼]#
圍繞
lax.dynamic_slice()
的便利包裝函式,應用於一個維度。這大致等同於沿指定軸應用的以下 Python 索引語法:
operand[..., start_index:start_index + slice_size]
。- 參數:
- 回傳:
包含切片的陣列。
- 回傳型別:
範例
這是一個一維範例
>>> x = jnp.arange(5) >>> dynamic_slice_in_dim(x, 1, 3) Array([1, 2, 3], dtype=int32)
與 jax.lax.dynamic_slice 類似,超出範圍的切片將被裁剪到有效範圍
>>> dynamic_slice_in_dim(x, 4, 3) Array([2, 3, 4], dtype=int32)
這是一個二維範例
>>> x = jnp.arange(12).reshape(3, 4) >>> x Array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)
>>> dynamic_slice_in_dim(x, 1, 2, axis=1) Array([[ 1, 2], [ 5, 6], [ 9, 10]], dtype=int32)