jax.numpy.mask_indices#
- jax.numpy.mask_indices(n, mask_func, k=0, *, size=None)[原始碼]#
傳回 (n, n) 陣列遮罩的索引。
- 參數:
- 傳回:
mask_func
為非零的索引元組。- 傳回類型:
參見
jax.numpy.triu_indices()
:計算triu()
的mask_indices
。jax.numpy.tril_indices()
:計算tril()
的mask_indices
。
範例
在內建遮罩函數上呼叫
mask_indices
>>> jnp.mask_indices(3, jnp.triu) (Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
>>> jnp.mask_indices(3, jnp.tril) (Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
在自訂遮罩函數上呼叫
mask_indices
>>> def mask_func(x, k=0): ... i = jnp.arange(x.shape[0])[:, None] ... j = jnp.arange(x.shape[1]) ... return (i + 1) % (j + 1 + k) == 0 >>> mask_func(jnp.ones((3, 3))) Array([[ True, False, False], [ True, True, False], [ True, False, True]], dtype=bool) >>> jnp.mask_indices(3, mask_func) (Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32))