jax.lax.dynamic_index_in_dim#
- jax.lax.dynamic_index_in_dim(operand, index, axis=0, keepdims=True)[原始碼]#
dynamic_slice 的便利包裝函式,用於執行整數索引。
這大致等同於沿指定軸套用的以下 Python 索引語法:
operand[..., index]
。- 參數:
- 傳回:
包含切片的陣列。
- 傳回類型:
範例
這是一個一維範例
>>> x = jnp.arange(5) >>> dynamic_index_in_dim(x, 1) Array([1], dtype=int32)
>>> dynamic_index_in_dim(x, 1, keepdims=False) Array(1, dtype=int32)
這是一個二維範例
>>> x = jnp.arange(12).reshape(3, 4) >>> x Array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)
>>> dynamic_index_in_dim(x, 1, axis=1, keepdims=False) Array([1, 5, 9], dtype=int32)