jax.numpy.median#

jax.numpy.median(a, axis=None, out=None, overwrite_input=False, keepdims=False)[原始碼]#

沿著給定軸回傳陣列元素的中位數。

JAX 版本的 numpy.median()

參數:
  • a (ArrayLike) – 輸入陣列。

  • axis (int | tuple[int, ...] | None) – 選用,整數或整數序列,預設值=None。計算中位數的軸。若為 None,則針對展平陣列計算中位數。

  • keepdims (bool) – 布林值,預設值=False。若為 true,則縮減的軸會保留在結果中,且大小為 1。

  • out (None) – JAX 未使用。

  • overwrite_input (bool) – JAX 未使用。

回傳值:

沿著給定軸的中位數陣列。

回傳型別:

Array

另請參閱

範例

預設情況下,中位數是針對展平陣列計算。

>>> x = jnp.array([[2, 4, 7, 1],
...                [3, 5, 9, 2],
...                [6, 1, 8, 3]])
>>> jnp.median(x)
Array(3.5, dtype=float32)

axis=1,則沿著軸 1 計算中位數。

>>> jnp.median(x, axis=1)
Array([3. , 4. , 4.5], dtype=float32)

keepdims=True,則輸出的 ndim 等於輸入的 ndim

>>> jnp.median(x, axis=1, keepdims=True)
Array([[3. ],
       [4. ],
       [4.5]], dtype=float32)