jax.numpy.count_nonzero#

jax.numpy.count_nonzero(a, axis=None, keepdims=False)[source]#

沿著給定軸返回非零元素的數量。

numpy.count_nonzero() 的 JAX 實作。

參數:
  • a (ArrayLike) – 輸入陣列。

  • axis (Axis) – 選項,整數或整數序列,預設值=None。 沿著其計算非零數量的軸。 如果為 None,則計算展平陣列內的數量。

  • keepdims (bool) – 布林值,預設值=False。 如果為 true,縮減軸會保留在結果中,大小為 1。

返回值:

一個陣列,其中包含沿輸入指定軸的非零元素數量。

返回類型:

Array

範例

預設情況下,jnp.count_nonzero 會計算沿所有軸的非零值。

>>> x = jnp.array([[1, 0, 0, 0],
...                [0, 0, 1, 0],
...                [1, 1, 1, 0]])
>>> jnp.count_nonzero(x)
Array(5, dtype=int32)

如果 axis=1,則沿軸 1 計數。

>>> jnp.count_nonzero(x, axis=1)
Array([1, 1, 3], dtype=int32)

若要保留輸入的維度,您可以設定 keepdims=True

>>> jnp.count_nonzero(x, axis=1, keepdims=True)
Array([[1],
       [1],
       [3]], dtype=int32)