jax.numpy.indices#

jax.numpy.indices(dimensions, dtype=None, sparse=False)[原始碼]#

產生網格索引陣列。

JAX 實作的 numpy.indices()

參數:
  • dimensions (Sequence[int]) – 網格的形狀。

  • dtype (DTypeLike | None | None) – 索引的 dtype(預設為整數)。

  • sparse (bool) – 若為 True,則傳回稀疏索引。預設值為 False,傳回密集索引。

傳回:

形狀為 (len(dimensions), *dimensions) 的陣列。若 sparse 為 False,或與 dimensions 長度相同的陣列序列。若 sparse 為 True。

傳回型別:

Array | tuple[Array, …]

另請參閱

範例

>>> jnp.indices((2, 3))
Array([[[0, 0, 0],
        [1, 1, 1]],

       [[0, 1, 2],
        [0, 1, 2]]], dtype=int32)
>>> jnp.indices((2, 3), sparse=True)
(Array([[0],
       [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32))