jax.numpy.mgrid#
- jax.numpy.mgrid = <jax._src.numpy.index_tricks._Mgrid object>#
傳回密集的多元「meshgrid」。
LAX 後端實作的
numpy.mgrid
。這是由jax.numpy.meshgrid()
提供功能的便利包裝函式,其中sparse=False
。參見
jnp.ogrid:jnp.mgrid 的開放/稀疏版本
範例
傳遞
[start:stop:step]
以產生類似於jax.numpy.arange()
的值>>> jnp.mgrid[0:4:1] Array([0, 1, 2, 3], dtype=int32)
傳遞虛數步長會產生類似於
jax.numpy.linspace()
的值>>> jnp.mgrid[0:1:4j] Array([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
多個切片可用於建立廣播的索引網格
>>> jnp.mgrid[:2, :3] Array([[[0, 0, 0], [1, 1, 1]], [[0, 1, 2], [0, 1, 2]]], dtype=int32)