jax.numpy.mask_indices#

jax.numpy.mask_indices(n, mask_func, k=0, *, size=None)[原始碼]#

傳回 (n, n) 陣列遮罩的索引。

參數:
  • n (int) – 靜態整數陣列維度。

  • mask_func (Callable[[ArrayLike, int], Array]) – 一個函數,它接受形狀為 (n, n) 的陣列和一個可選的偏移量 k,並傳回形狀為 (n, n) 的遮罩。具有此簽名的函數範例為 triu()tril()

  • k (int) – 傳遞給 mask_func 的純量值。

  • size (int | None | None) – 可選參數,指定輸出陣列的靜態大小。這會在從遮罩產生索引時傳遞給 nonzero()

傳回:

mask_func 為非零的索引元組。

傳回類型:

tuple[Array, Array]

參見

範例

在內建遮罩函數上呼叫 mask_indices

>>> jnp.mask_indices(3, jnp.triu)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
>>> jnp.mask_indices(3, jnp.tril)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))

在自訂遮罩函數上呼叫 mask_indices

>>> def mask_func(x, k=0):
...   i = jnp.arange(x.shape[0])[:, None]
...   j = jnp.arange(x.shape[1])
...   return (i + 1) % (j + 1 + k) == 0
>>> mask_func(jnp.ones((3, 3)))
Array([[ True, False, False],
       [ True,  True, False],
       [ True, False,  True]], dtype=bool)
>>> jnp.mask_indices(3, mask_func)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32))