jax.numpy.argwhere#
- jax.numpy.argwhere(a, *, size=None, fill_value=None)[原始碼]#
尋找非零陣列元素的索引
numpy.argwhere() 的 JAX 實作。
jnp.argwhere(x)
本質上等同於jnp.column_stack(jnp.nonzero(x))
,並針對零維(即純量)輸入進行特殊處理。由於
argwhere
輸出的尺寸取決於資料,因此該函數通常與 JIT 不相容。JAX 版本新增了可選的size
參數,用於靜態指定輸出前導維度的尺寸 - 為了使jnp.argwhere
能使用非靜態運算元進行編譯,必須靜態指定此參數。有關size
及其語義的完整討論,請參閱jax.numpy.nonzero()
。- 參數:
a (ArrayLike) – 要尋找非零元素的陣列
size (int | None | None) – 可選整數,靜態指定預期的非零元素數量。為了在 JAX 轉換(如
jax.jit()
)中使用argwhere
,必須指定此參數。更多資訊請參閱jax.numpy.nonzero()
。fill_value (ArrayLike | None | None) – 可選陣列,用於在指定
size
時指定填充值。更多資訊請參閱jax.numpy.nonzero()
。
- 傳回:
形狀為
[size, x.ndim]
的二維陣列。如果未將size
指定為參數,則其等於x
中非零元素的數量。- 傳回類型:
範例
二維陣列
>>> x = jnp.array([[1, 0, 2], ... [0, 3, 0]]) >>> jnp.argwhere(x) Array([[0, 0], [0, 2], [1, 1]], dtype=int32)
使用
jax.numpy.column_stack()
和jax.numpy.nonzero()
的等效計算>>> jnp.column_stack(jnp.nonzero(x)) Array([[0, 0], [0, 2], [1, 1]], dtype=int32)
零維(即純量)輸入的特殊情況
>>> jnp.argwhere(1) Array([], shape=(1, 0), dtype=int32) >>> jnp.argwhere(0) Array([], shape=(0, 0), dtype=int32)