jax.numpy.mgrid#

jax.numpy.mgrid = <jax._src.numpy.index_tricks._Mgrid object>#

傳回密集的多元「meshgrid」。

LAX 後端實作的 numpy.mgrid。這是由 jax.numpy.meshgrid() 提供功能的便利包裝函式,其中 sparse=False

參見

jnp.ogrid:jnp.mgrid 的開放/稀疏版本

範例

傳遞 [start:stop:step] 以產生類似於 jax.numpy.arange() 的值

>>> jnp.mgrid[0:4:1]
Array([0, 1, 2, 3], dtype=int32)

傳遞虛數步長會產生類似於 jax.numpy.linspace() 的值

>>> jnp.mgrid[0:1:4j]
Array([0.        , 0.33333334, 0.6666667 , 1.        ], dtype=float32)

多個切片可用於建立廣播的索引網格

>>> jnp.mgrid[:2, :3]
Array([[[0, 0, 0],
        [1, 1, 1]],
       [[0, 1, 2],
        [0, 1, 2]]], dtype=int32)