jax.numpy.median#
- jax.numpy.median(a, axis=None, out=None, overwrite_input=False, keepdims=False)[原始碼]#
沿著給定軸回傳陣列元素的中位數。
JAX 版本的
numpy.median()
。- 參數:
- 回傳值:
沿著給定軸的中位數陣列。
- 回傳型別:
另請參閱
jax.numpy.mean()
:計算給定軸上陣列元素的平均值。jax.numpy.max()
:計算給定軸上陣列元素的最大值。jax.numpy.min()
:計算給定軸上陣列元素的最小值。
範例
預設情況下,中位數是針對展平陣列計算。
>>> x = jnp.array([[2, 4, 7, 1], ... [3, 5, 9, 2], ... [6, 1, 8, 3]]) >>> jnp.median(x) Array(3.5, dtype=float32)
若
axis=1
,則沿著軸 1 計算中位數。>>> jnp.median(x, axis=1) Array([3. , 4. , 4.5], dtype=float32)
若
keepdims=True
,則輸出的ndim
等於輸入的ndim
。>>> jnp.median(x, axis=1, keepdims=True) Array([[3. ], [4. ], [4.5]], dtype=float32)