jax.numpy.argwhere#

jax.numpy.argwhere(a, *, size=None, fill_value=None)[原始碼]#

尋找非零陣列元素的索引

numpy.argwhere() 的 JAX 實作。

jnp.argwhere(x) 本質上等同於 jnp.column_stack(jnp.nonzero(x)),並針對零維(即純量)輸入進行特殊處理。

由於 argwhere 輸出的尺寸取決於資料,因此該函數通常與 JIT 不相容。JAX 版本新增了可選的 size 參數,用於靜態指定輸出前導維度的尺寸 - 為了使 jnp.argwhere 能使用非靜態運算元進行編譯,必須靜態指定此參數。有關 size 及其語義的完整討論,請參閱 jax.numpy.nonzero()

參數:
  • a (ArrayLike) – 要尋找非零元素的陣列

  • size (int | None | None) – 可選整數,靜態指定預期的非零元素數量。為了在 JAX 轉換(如 jax.jit())中使用 argwhere,必須指定此參數。更多資訊請參閱 jax.numpy.nonzero()

  • fill_value (ArrayLike | None | None) – 可選陣列,用於在指定 size 時指定填充值。更多資訊請參閱 jax.numpy.nonzero()

傳回:

形狀為 [size, x.ndim] 的二維陣列。如果未將 size 指定為參數,則其等於 x 中非零元素的數量。

傳回類型:

Array

範例

二維陣列

>>> x = jnp.array([[1, 0, 2],
...                [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

使用 jax.numpy.column_stack()jax.numpy.nonzero() 的等效計算

>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

零維(即純量)輸入的特殊情況

>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)