jax.numpy.indices#
- jax.numpy.indices(dimensions, dtype=None, sparse=False)[原始碼]#
產生網格索引陣列。
JAX 實作的
numpy.indices()
。- 參數:
- 傳回:
形狀為
(len(dimensions), *dimensions)
的陣列。若sparse
為 False,或與dimensions
長度相同的陣列序列。若sparse
為 True。- 傳回型別:
另請參閱
jax.numpy.meshgrid()
:從任意輸入陣列產生網格。jax.numpy.mgrid
:使用切片語法產生密集索引。jax.numpy.ogrid
:使用切片語法產生稀疏索引。
範例
>>> jnp.indices((2, 3)) Array([[[0, 0, 0], [1, 1, 1]], [[0, 1, 2], [0, 1, 2]]], dtype=int32) >>> jnp.indices((2, 3), sparse=True) (Array([[0], [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32))