jax.numpy.ix_#

jax.numpy.ix_(*args)[原始碼]#

從 N 個一維序列傳回多維網格 (開放網格)。

numpy.ix_() 的 JAX 實作。

參數:

*args (ArrayLike) – N 個一維陣列

傳回:

形成開放網格的 Jax 陣列元組,每個陣列都有 N 維度。

傳回類型:

tuple[Array, …]

範例

>>> rows = jnp.array([0, 2])
>>> cols = jnp.array([1, 3])
>>> open_mesh = jnp.ix_(rows, cols)
>>> open_mesh
(Array([[0],
      [2]], dtype=int32), Array([[1, 3]], dtype=int32))
>>> [grid.shape for grid in open_mesh]
[(2, 1), (1, 2)]
>>> x = jnp.array([[10, 20, 30, 40],
...                [50, 60, 70, 80],
...                [90, 100, 110, 120],
...                [130, 140, 150, 160]])
>>> x[open_mesh]
Array([[ 20,  40],
       [100, 120]], dtype=int32)