jax.numpy.histogram2d#

jax.numpy.histogram2d(x, y, bins=10, range=None, weights=None, density=None)[原始碼]#

計算二維直方圖。

JAX 實作的 numpy.histogram2d()

參數:
  • x (ArrayLike) – 要分箱點的一維 x 值陣列。

  • y (ArrayLike) – 要分箱點的一維 y 值陣列。

  • bins (ArrayLike | list[ArrayLike]) – 指定直方圖中的箱數 (預設值:10)。bins 也可以是一個陣列,指定箱邊緣的位置,或一對整數或一對陣列,指定每個維度中的箱數。

  • range (Sequence[None | Array | Sequence[ArrayLike]] | None | None) – 陣列對或列表對,格式為 [[xmin, xmax], [ymin, ymax]],指定每個維度中資料的範圍。如果未指定,則從資料推斷範圍。

  • weights (ArrayLike | None | None) – 一個可選的陣列,指定資料點的權重。應與 xy 的形狀相同。如果未指定,則每個資料點的權重相等。

  • density (bool | None | None) – 如果為 True,則傳回單位面積計數的正規化直方圖。如果為 False (預設值),則傳回每個箱的 (加權) 計數。

傳回值:

陣列的元組 (histogram, x_edges, y_edges),其中 histogram 包含聚合資料,而 x_edgesy_edges 指定箱的邊界。

傳回型別:

tuple[Array, Array, Array]

參見

範例

>>> x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
>>> y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18])
>>> counts, x_edges, y_edges = jnp.histogram2d(x, y, bins=8)
>>> counts.shape
(8, 8)
>>> x_edges
Array([ 1.,  4.,  7., 10., 13., 16., 19., 22., 25.], dtype=float32)
>>> y_edges
Array([ 2.,  4.,  6.,  8., 10., 12., 14., 16., 18.], dtype=float32)

指定箱範圍

>>> counts, x_edges, y_edges = jnp.histogram2d(x, y, range=[(0, 25), (0, 25)], bins=5)
>>> counts.shape
(5, 5)
>>> x_edges
Array([ 0.,  5., 10., 15., 20., 25.], dtype=float32)
>>> y_edges
Array([ 0.,  5., 10., 15., 20., 25.], dtype=float32)

明確指定箱邊緣

>>> x_edges = jnp.array([0, 10, 20, 30])
>>> y_edges = jnp.array([0, 10, 20, 30])
>>> counts, _, _ = jnp.histogram2d(x, y, bins=[x_edges, y_edges])
>>> counts
Array([[3, 0, 0],
       [1, 3, 0],
       [0, 1, 0]], dtype=int32)

使用 density=True 傳回正規化直方圖

>>> density, x_edges, y_edges = jnp.histogram2d(x, y, density=True)
>>> dx = jnp.diff(x_edges)
>>> dy = jnp.diff(y_edges)
>>> normed_sum = jnp.sum(density * dx[:, None] * dy[None, :])
>>> jnp.allclose(normed_sum, 1.0)
Array(True, dtype=bool)