jax.lax.slice_in_dim#

jax.lax.slice_in_dim(operand, start_index, limit_index, stride=1, axis=0)[原始碼]#

圍繞 lax.slice() 的便利包裝函式,僅適用於一個維度。

這實際上等同於 operand[..., start_index:limit_index:stride],索引應用於指定的軸。

參數:
  • operand (Array | np.ndarray) – 要切片的陣列。

  • start_index (int | None) – 可選的開始索引 (預設為零)

  • limit_index (int | None) – 可選的結束索引 (預設為 operand.shape[axis])

  • stride (int) – 可選的步幅 (預設為 1)

  • axis (int) – 應用切片的軸 (預設為 0)

返回:

包含切片的陣列。

返回類型:

Array

範例

這是一個一維範例

>>> x = jnp.arange(4)
>>> lax.slice_in_dim(x, 1, 3)
Array([1, 2], dtype=int32)

這是一些二維範例

>>> x = jnp.arange(12).reshape(4, 3)
>>> x
Array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3)
Array([[3, 4, 5],
       [6, 7, 8]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3, axis=1)
Array([[ 1,  2],
       [ 4,  5],
       [ 7,  8],
       [10, 11]], dtype=int32)