jax.numpy.count_nonzero#
- jax.numpy.count_nonzero(a, axis=None, keepdims=False)[source]#
沿著給定軸返回非零元素的數量。
numpy.count_nonzero()
的 JAX 實作。- 參數:
a (ArrayLike) – 輸入陣列。
axis (Axis) – 選項,整數或整數序列,預設值=None。 沿著其計算非零數量的軸。 如果為 None,則計算展平陣列內的數量。
keepdims (bool) – 布林值,預設值=False。 如果為 true,縮減軸會保留在結果中,大小為 1。
- 返回值:
一個陣列,其中包含沿輸入指定軸的非零元素數量。
- 返回類型:
範例
預設情況下,
jnp.count_nonzero
會計算沿所有軸的非零值。>>> x = jnp.array([[1, 0, 0, 0], ... [0, 0, 1, 0], ... [1, 1, 1, 0]]) >>> jnp.count_nonzero(x) Array(5, dtype=int32)
如果
axis=1
,則沿軸 1 計數。>>> jnp.count_nonzero(x, axis=1) Array([1, 1, 3], dtype=int32)
若要保留輸入的維度,您可以設定
keepdims=True
。>>> jnp.count_nonzero(x, axis=1, keepdims=True) Array([[1], [1], [3]], dtype=int32)