jax.numpy.histogram#

jax.numpy.histogram(a, bins=10, range=None, weights=None, density=None)[原始碼]#

計算一維直方圖。

JAX 實作的 numpy.histogram()

參數:
  • a (ArrayLike) – 要分箱的值陣列。可以是任何大小或維度。

  • bins (ArrayLike) – 指定直方圖中的箱數 (預設值:10)。bins 也可以是指定箱邊位置的陣列。

  • range (Sequence[ArrayLike] | None | None) – 純量元組。指定資料的範圍。如果未指定,則從資料推斷範圍。

  • weights (ArrayLike | None | None) – 一個可選的陣列,用於指定資料點的權重。應與 a 廣播相容。如果未指定,則每個資料點的權重均等。

  • density (bool | None | None) – 如果為 True,則傳回單位長度計數的標準化直方圖。如果為 False (預設值),則傳回每個箱的 (加權) 計數。

傳回:

陣列元組 (histogram, bin_edges),其中 histogram 包含彙總資料,而 bin_edges 指定箱的邊界。

傳回類型:

tuple[Array, Array]

另請參閱

範例

>>> a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
>>> counts, bin_edges = jnp.histogram(a, bins=8)
>>> print(counts)
[3. 0. 0. 2. 1. 0. 1. 1.]
>>> print(bin_edges)
[ 1.  4.  7. 10. 13. 16. 19. 22. 25.]

指定箱範圍

>>> counts, bin_edges = jnp.histogram(a, range=(0, 25), bins=5)
>>> print(counts)
[3. 0. 2. 2. 1.]
>>> print(bin_edges)
[ 0.  5. 10. 15. 20. 25.]

明確指定箱邊

>>> bin_edges = jnp.array([0, 10, 20, 30])
>>> counts, _ = jnp.histogram(a, bins=bin_edges)
>>> print(counts)
[3. 4. 1.]

使用 density=True 傳回標準化直方圖

>>> density, bin_edges = jnp.histogram(a, density=True)
>>> dx = jnp.diff(bin_edges)
>>> normed_sum = jnp.sum(density * dx)
>>> jnp.allclose(normed_sum, 1.0)
Array(True, dtype=bool)