jax.numpy.bincount#
- jax.numpy.bincount(x, weights=None, minlength=0, *, length=None)[原始碼]#
計算整數陣列中每個值的出現次數。
JAX 版本的
numpy.bincount()
實作。對於正整數陣列
x
,此函式會傳回大小為x.max() + 1
的陣列counts
,使得counts[i]
包含值i
在x
中出現的次數。JAX 版本與 NumPy 版本有一些差異
在 NumPy 中,傳遞具有負數項目的陣列
x
會導致錯誤。在 JAX 中,負值會被截斷為零。JAX 新增了選用的
length
參數,可用於靜態指定輸出陣列的長度,以便此函式可以與jax.jit()
等轉換一起使用。在這種情況下,大於 length + 1 的項目將會被捨棄。
- 參數:
- 傳回值:
一個計數或加總權重的陣列,反映
x
中值的出現次數。- 傳回型別:
範例
基本 bincount
>>> x = jnp.array([1, 1, 2, 3, 3, 3]) >>> jnp.bincount(x) Array([0, 2, 1, 3], dtype=int32)
加權 bincount
>>> weights = jnp.array([1, 2, 3, 4, 5, 6]) >>> jnp.bincount(x, weights) Array([ 0, 3, 3, 15], dtype=int32)
指定靜態
length
使其與 jit 相容>>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length']) >>> jit_bincount(x, length=5) Array([0, 2, 1, 3, 0], dtype=int32)
任何負數都會被截斷到第一個 bin,而超出指定
length
的數字會被捨棄>>> x = jnp.array([-1, -1, 1, 3, 10]) >>> jnp.bincount(x, length=5) Array([2, 1, 0, 1, 0], dtype=int32)