jax.numpy.nonzero#
- jax.numpy.nonzero(a, *, size=None, fill_value=None)[原始碼]#
傳回陣列中非零元素的索引。
JAX 版本的
numpy.nonzero()
實作。由於
nonzero
的輸出大小取決於資料,因此此函式與 JIT 和其他轉換不相容。 JAX 版本新增了選用的size
引數,必須靜態指定此引數,才能在 JAX 轉換中使用jnp.nonzero
。- 參數:
- 傳回值:
長度為
a.ndim
的 JAX 陣列的元組,其中包含每個非零值的索引。- 傳回型別:
範例
一維陣列傳回長度為 1 的索引元組
>>> x = jnp.array([0, 5, 0, 6, 0, 7]) >>> jnp.nonzero(x) (Array([1, 3, 5], dtype=int32),)
二維陣列傳回長度為 2 的索引元組
>>> x = jnp.array([[0, 5, 0], ... [6, 0, 7]]) >>> jnp.nonzero(x) (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
在任一情況下,產生的索引元組都可以直接用於擷取非零值
>>> indices = jnp.nonzero(x) >>> x[indices] Array([5, 6, 7], dtype=int32)
nonzero
的輸出具有動態形狀,因為傳回的索引數量取決於輸入陣列的內容。 因此,它與 JIT 和其他 JAX 轉換不相容>>> x = jnp.array([0, 5, 0, 6, 0, 7]) >>> jax.jit(jnp.nonzero)(x) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
這可以透過傳遞靜態
size
參數來指定所需的輸出形狀來解決>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size') >>> nonzero_jit(x, size=3) (Array([1, 3, 5], dtype=int32),)
如果
size
與真實大小不符,則結果將被截斷或填補>>> nonzero_jit(x, size=2) # size < 3: indices are truncated (Array([1, 3], dtype=int32),) >>> nonzero_jit(x, size=5) # size > 3: indices are padded with zeros. (Array([1, 3, 5, 0, 0], dtype=int32),)
您可以使用
fill_value
引數為填補指定自訂填補值>>> nonzero_jit(x, size=5, fill_value=len(x)) (Array([1, 3, 5, 6, 6], dtype=int32),)