jax.lax.index_in_dim#

jax.lax.index_in_dim(operand, index, axis=0, keepdims=True)[原始碼]#

圍繞 lax.slice() 的便利包裝器,用於執行整數索引。

這實際上等同於 operand[..., start_index:limit_index:stride],索引應用於指定的軸。

參數:
  • operand (Array | np.ndarray) – 要索引的陣列。

  • index (int) – 整數索引

  • axis (int) – 應用索引的軸 (預設為 0)

  • keepdims (bool) – 布林值,指定輸出陣列是否應保留輸入的秩 (預設值=True)

傳回:

指定索引處的子陣列。

傳回類型:

Array

範例

這是一個一維範例

>>> x = jnp.arange(4)
>>> lax.index_in_dim(x, 2)
Array([2], dtype=int32)
>>> lax.index_in_dim(x, 2, keepdims=False)
Array(2, dtype=int32)

這是一些二維範例

>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> lax.index_in_dim(x, 1)
Array([[4, 5, 6, 7]], dtype=int32)
>>> lax.index_in_dim(x, 1, axis=1, keepdims=False)
Array([1, 5, 9], dtype=int32)