jax.numpy.nonzero#

jax.numpy.nonzero(a, *, size=None, fill_value=None)[原始碼]#

傳回陣列中非零元素的索引。

JAX 版本的 numpy.nonzero() 實作。

由於 nonzero 的輸出大小取決於資料,因此此函式與 JIT 和其他轉換不相容。 JAX 版本新增了選用的 size 引數,必須靜態指定此引數,才能在 JAX 轉換中使用 jnp.nonzero

參數:
  • a (ArrayLike) – N 維陣列。

  • size (int | None | None) – 選用的靜態整數,指定要傳回的非零條目數量。 如果非零元素多於指定的 size,則索引將在末尾截斷。 如果非零元素少於指定的大小,則索引將以 fill_value 填補,預設值為零。

  • fill_value (None | ArrayLike | tuple[ArrayLike, ...] | None) – 指定 size 時選用的填補值。 預設值為 0。

傳回值:

長度為 a.ndim 的 JAX 陣列的元組,其中包含每個非零值的索引。

傳回型別:

tuple[Array, …]

範例

一維陣列傳回長度為 1 的索引元組

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)

二維陣列傳回長度為 2 的索引元組

>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))

在任一情況下,產生的索引元組都可以直接用於擷取非零值

>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)

nonzero 的輸出具有動態形狀,因為傳回的索引數量取決於輸入陣列的內容。 因此,它與 JIT 和其他 JAX 轉換不相容

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x)  
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.

這可以透過傳遞靜態 size 參數來指定所需的輸出形狀來解決

>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)

如果 size 與真實大小不符,則結果將被截斷或填補

>>> nonzero_jit(x, size=2)  # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5)  # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)

您可以使用 fill_value 引數為填補指定自訂填補值

>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)