jax.numpy.histogramdd#
- jax.numpy.histogramdd(sample, bins=10, range=None, weights=None, density=None)[原始碼]#
計算 N 維直方圖。
JAX 版本的
numpy.histogramdd()
。- 參數:
sample (ArrayLike) – 形狀為
(N, D)
的輸入陣列,表示D
維空間中的N
個點。bins (ArrayLike | list[ArrayLike]) – 指定直方圖每個維度的 bin 數量。(預設值:10)。也可以是長度為 D 的整數序列或 bin 邊緣陣列。
range (Sequence[None | Array | Sequence[ArrayLike]] | None | None) – 長度為 D 的配對序列,指定每個維度的範圍。如果未指定,則範圍從資料推斷。
weights (ArrayLike | None | None) – 可選的形狀為
(N,)
的陣列,指定資料點的權重。應與sample
的形狀相同。如果未指定,則每個資料點的權重均等。density (bool | None | None) – 如果為 True,則傳回單位體積計數的標準化直方圖。如果為 False(預設值),則傳回每個 bin 的(加權)計數。
- 傳回值:
陣列元組
(histogram, bin_edges)
,其中histogram
包含聚合資料,而bin_edges
指定 bin 的邊界。- 傳回型別:
另請參閱
jax.numpy.histogram()
:計算一維陣列的直方圖。jax.numpy.histogram2d()
:計算二維陣列的直方圖。jax.numpy.histogram_bin_edges()
:計算直方圖的 bin 邊緣。
範例
三個維度中 100 個點的直方圖
>>> key = jax.random.key(42) >>> a = jax.random.normal(key, (100, 3)) >>> counts, bin_edges = jnp.histogramdd(a, bins=6, ... range=[(-3, 3), (-3, 3), (-3, 3)]) >>> counts.shape (6, 6, 6) >>> bin_edges [Array([-3., -2., -1., 0., 1., 2., 3.], dtype=float32), Array([-3., -2., -1., 0., 1., 2., 3.], dtype=float32), Array([-3., -2., -1., 0., 1., 2., 3.], dtype=float32)]
使用
density=True
傳回標準化直方圖>>> density, bin_edges = jnp.histogramdd(a, density=True) >>> bin_widths = map(jnp.diff, bin_edges) >>> dx, dy, dz = jnp.meshgrid(*bin_widths, indexing='ij') >>> normed = jnp.sum(density * dx * dy * dz) >>> jnp.allclose(normed, 1.0) Array(True, dtype=bool)