jax.lax.slice_in_dim#
- jax.lax.slice_in_dim(operand, start_index, limit_index, stride=1, axis=0)[原始碼]#
圍繞
lax.slice()
的便利包裝函式,僅適用於一個維度。這實際上等同於
operand[..., start_index:limit_index:stride]
,索引應用於指定的軸。- 參數:
- 返回:
包含切片的陣列。
- 返回類型:
範例
這是一個一維範例
>>> x = jnp.arange(4) >>> lax.slice_in_dim(x, 1, 3) Array([1, 2], dtype=int32)
這是一些二維範例
>>> x = jnp.arange(12).reshape(4, 3) >>> x Array([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3) Array([[3, 4, 5], [6, 7, 8]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3, axis=1) Array([[ 1, 2], [ 4, 5], [ 7, 8], [10, 11]], dtype=int32)