jax.numpy.average#

jax.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)[原始碼]#

計算權重平均值。

JAX 版本的 numpy.average()

參數:
  • a (ArrayLike) – 要計算平均值的陣列

  • axis (Axis | None) – 一個可選的整數或整數序列,指定要計算平均值的軸。如果未指定,則沿所有軸計算平均值。

  • weights (ArrayLike | None | None) – 權重平均的可選權重陣列。必須與 a 廣播相容。

  • returned (bool) – 如果為 False (預設值),則僅回傳平均值。如果為 True,則同時回傳平均值和正規化因子 (即權重總和)。

  • keepdims (bool) – 如果為 True,則縮減的軸會保留在結果中,大小為 1。如果為 False (預設值),則會擠出縮減的軸。

回傳值:

如果 returned 為 True,則為陣列 average 或陣列元組 (average, normalization)

回傳型別:

Array | tuple[Array, Array]

參見

範例

簡單平均值

>>> x = jnp.array([1, 2, 3, 2, 4])
>>> jnp.average(x)
Array(2.4, dtype=float32)

權重平均值

>>> weights = jnp.array([2, 1, 3, 2, 2])
>>> jnp.average(x, weights=weights)
Array(2.5, dtype=float32)

使用 returned=True 以選擇性地回傳正規化值,即權重總和

>>> jnp.average(x, returned=True)
(Array(2.4, dtype=float32), Array(5., dtype=float32))
>>> jnp.average(x, weights=weights, returned=True)
(Array(2.5, dtype=float32), Array(10., dtype=float32))

沿指定軸的權重平均值

>>> x = jnp.array([[8, 2, 7],
...                [3, 6, 4]])
>>> weights = jnp.array([1, 2, 3])
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)