jax.numpy.histogram2d#
- jax.numpy.histogram2d(x, y, bins=10, range=None, weights=None, density=None)[原始碼]#
計算二維直方圖。
JAX 實作的
numpy.histogram2d()
。- 參數:
x (ArrayLike) – 要分箱點的一維 x 值陣列。
y (ArrayLike) – 要分箱點的一維 y 值陣列。
bins (ArrayLike | list[ArrayLike]) – 指定直方圖中的箱數 (預設值:10)。
bins
也可以是一個陣列,指定箱邊緣的位置,或一對整數或一對陣列,指定每個維度中的箱數。range (Sequence[None | Array | Sequence[ArrayLike]] | None | None) – 陣列對或列表對,格式為
[[xmin, xmax], [ymin, ymax]]
,指定每個維度中資料的範圍。如果未指定,則從資料推斷範圍。weights (ArrayLike | None | None) – 一個可選的陣列,指定資料點的權重。應與
x
和y
的形狀相同。如果未指定,則每個資料點的權重相等。density (bool | None | None) – 如果為 True,則傳回單位面積計數的正規化直方圖。如果為 False (預設值),則傳回每個箱的 (加權) 計數。
- 傳回值:
陣列的元組
(histogram, x_edges, y_edges)
,其中histogram
包含聚合資料,而x_edges
和y_edges
指定箱的邊界。- 傳回型別:
參見
jax.numpy.histogram()
:計算一維陣列的直方圖。jax.numpy.histogramdd()
:計算 N 維陣列的直方圖。jax.numpy.histogram_bin_edges()
:計算直方圖的箱邊緣。
範例
>>> x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25]) >>> y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18]) >>> counts, x_edges, y_edges = jnp.histogram2d(x, y, bins=8) >>> counts.shape (8, 8) >>> x_edges Array([ 1., 4., 7., 10., 13., 16., 19., 22., 25.], dtype=float32) >>> y_edges Array([ 2., 4., 6., 8., 10., 12., 14., 16., 18.], dtype=float32)
指定箱範圍
>>> counts, x_edges, y_edges = jnp.histogram2d(x, y, range=[(0, 25), (0, 25)], bins=5) >>> counts.shape (5, 5) >>> x_edges Array([ 0., 5., 10., 15., 20., 25.], dtype=float32) >>> y_edges Array([ 0., 5., 10., 15., 20., 25.], dtype=float32)
明確指定箱邊緣
>>> x_edges = jnp.array([0, 10, 20, 30]) >>> y_edges = jnp.array([0, 10, 20, 30]) >>> counts, _, _ = jnp.histogram2d(x, y, bins=[x_edges, y_edges]) >>> counts Array([[3, 0, 0], [1, 3, 0], [0, 1, 0]], dtype=int32)
使用
density=True
傳回正規化直方圖>>> density, x_edges, y_edges = jnp.histogram2d(x, y, density=True) >>> dx = jnp.diff(x_edges) >>> dy = jnp.diff(y_edges) >>> normed_sum = jnp.sum(density * dx[:, None] * dy[None, :]) >>> jnp.allclose(normed_sum, 1.0) Array(True, dtype=bool)