jax.numpy.average#
- jax.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)[原始碼]#
計算權重平均值。
JAX 版本的
numpy.average()
。- 參數:
- 回傳值:
如果
returned
為 True,則為陣列average
或陣列元組(average, normalization)
。- 回傳型別:
參見
jax.numpy.mean()
:未加權平均值。
範例
簡單平均值
>>> x = jnp.array([1, 2, 3, 2, 4]) >>> jnp.average(x) Array(2.4, dtype=float32)
權重平均值
>>> weights = jnp.array([2, 1, 3, 2, 2]) >>> jnp.average(x, weights=weights) Array(2.5, dtype=float32)
使用
returned=True
以選擇性地回傳正規化值,即權重總和>>> jnp.average(x, returned=True) (Array(2.4, dtype=float32), Array(5., dtype=float32)) >>> jnp.average(x, weights=weights, returned=True) (Array(2.5, dtype=float32), Array(10., dtype=float32))
沿指定軸的權重平均值
>>> x = jnp.array([[8, 2, 7], ... [3, 6, 4]]) >>> weights = jnp.array([1, 2, 3]) >>> jnp.average(x, weights=weights, axis=1) Array([5.5, 4.5], dtype=float32)