jax.numpy.meshgrid#

jax.numpy.meshgrid(*xi, copy=True, sparse=False, indexing='xy')[原始碼]#

從 N 個一維向量建構 N 維網格陣列。

numpy.meshgrid() 的 JAX 實作。

參數:
  • xi (ArrayLike) – 要轉換為網格的 N 個陣列。

  • copy (bool) – 是否複製輸入陣列。JAX 僅支援 copy=True,但在 JIT 編譯下,編譯器可能會選擇避免複製。

  • sparse (bool) – 若為 False (預設),則每個傳回的陣列形狀為 [len(x1), len(x2), ..., len(xN)]。若為 False,則傳回的陣列形狀為 [1, 1, ..., len(xi), ..., 1, 1]

  • indexing (str) – 選項為 'xy' 代表笛卡爾索引 (預設) 或 'ij' 代表矩陣索引。

傳回:

長度為 N 的網格陣列列表。

傳回型別:

list[Array]

另請參閱

範例

對於以下範例,我們將使用這些一維陣列作為輸入

>>> x = jnp.array([1, 2])
>>> y = jnp.array([10, 20, 30])

2D 笛卡爾網格

>>> x_grid, y_grid = jnp.meshgrid(x, y)
>>> print(x_grid)
[[1 2]
 [1 2]
 [1 2]]
>>> print(y_grid)
[[10 10]
 [20 20]
 [30 30]]

2D 稀疏笛卡爾網格

>>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True)
>>> print(x_grid)
[[1 2]]
>>> print(y_grid)
[[10]
 [20]
 [30]]

2D 矩陣索引網格

>>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij')
>>> print(x_grid)
[[1 1 1]
 [2 2 2]]
>>> print(y_grid)
[[10 20 30]
 [10 20 30]]