jax.numpy.bincount#

jax.numpy.bincount(x, weights=None, minlength=0, *, length=None)[原始碼]#

計算整數陣列中每個值的出現次數。

JAX 版本的 numpy.bincount() 實作。

對於正整數陣列 x,此函式會傳回大小為 x.max() + 1 的陣列 counts,使得 counts[i] 包含值 ix 中出現的次數。

JAX 版本與 NumPy 版本有一些差異

  • 在 NumPy 中,傳遞具有負數項目的陣列 x 會導致錯誤。在 JAX 中,負值會被截斷為零。

  • JAX 新增了選用的 length 參數,可用於靜態指定輸出陣列的長度,以便此函式可以與 jax.jit() 等轉換一起使用。在這種情況下,大於 length + 1 的項目將會被捨棄。

參數:
  • x (ArrayLike) – 正整數的 N 維陣列

  • weights (ArrayLike | None | None) – 與 x 相關聯的選用權重陣列。如果未指定,則每個項目的權重將為 1

  • minlength (int) – 輸出計數陣列的最小長度。

  • length (int | None | None) – 輸出計數陣列的長度。必須靜態指定,才能讓 bincountjax.jit() 和其他 JAX 轉換一起使用。

傳回值:

一個計數或加總權重的陣列,反映 x 中值的出現次數。

傳回型別:

Array

範例

基本 bincount

>>> x = jnp.array([1, 1, 2, 3, 3, 3])
>>> jnp.bincount(x)
Array([0, 2, 1, 3], dtype=int32)

加權 bincount

>>> weights = jnp.array([1, 2, 3, 4, 5, 6])
>>> jnp.bincount(x, weights)
Array([ 0,  3,  3, 15], dtype=int32)

指定靜態 length 使其與 jit 相容

>>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length'])
>>> jit_bincount(x, length=5)
Array([0, 2, 1, 3, 0], dtype=int32)

任何負數都會被截斷到第一個 bin,而超出指定 length 的數字會被捨棄

>>> x = jnp.array([-1, -1, 1, 3, 10])
>>> jnp.bincount(x, length=5)
Array([2, 1, 0, 1, 0], dtype=int32)