jax.numpy.histogramdd#

jax.numpy.histogramdd(sample, bins=10, range=None, weights=None, density=None)[原始碼]#

計算 N 維直方圖。

JAX 版本的 numpy.histogramdd()

參數:
  • sample (ArrayLike) – 形狀為 (N, D) 的輸入陣列,表示 D 維空間中的 N 個點。

  • bins (ArrayLike | list[ArrayLike]) – 指定直方圖每個維度的 bin 數量。(預設值:10)。也可以是長度為 D 的整數序列或 bin 邊緣陣列。

  • range (Sequence[None | Array | Sequence[ArrayLike]] | None | None) – 長度為 D 的配對序列,指定每個維度的範圍。如果未指定,則範圍從資料推斷。

  • weights (ArrayLike | None | None) – 可選的形狀為 (N,) 的陣列,指定資料點的權重。應與 sample 的形狀相同。如果未指定,則每個資料點的權重均等。

  • density (bool | None | None) – 如果為 True,則傳回單位體積計數的標準化直方圖。如果為 False(預設值),則傳回每個 bin 的(加權)計數。

傳回值:

陣列元組 (histogram, bin_edges),其中 histogram 包含聚合資料,而 bin_edges 指定 bin 的邊界。

傳回型別:

tuple[Array, list[Array]]

另請參閱

範例

三個維度中 100 個點的直方圖

>>> key = jax.random.key(42)
>>> a = jax.random.normal(key, (100, 3))
>>> counts, bin_edges = jnp.histogramdd(a, bins=6,
...                                     range=[(-3, 3), (-3, 3), (-3, 3)])
>>> counts.shape
(6, 6, 6)
>>> bin_edges  
[Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32)]

使用 density=True 傳回標準化直方圖

>>> density, bin_edges = jnp.histogramdd(a, density=True)
>>> bin_widths = map(jnp.diff, bin_edges)
>>> dx, dy, dz = jnp.meshgrid(*bin_widths, indexing='ij')
>>> normed = jnp.sum(density * dx * dy * dz)
>>> jnp.allclose(normed, 1.0)
Array(True, dtype=bool)