jax.numpy.nanmedian#

jax.numpy.nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False)[原始碼]#

傳回沿著給定軸的陣列元素中位數,忽略 NaN 值。

JAX 版本的 numpy.nanmedian()

參數:
  • a (ArrayLike) – 輸入陣列。

  • axis (int | tuple[int, ...] | None) – 選項,整數或整數序列,預設值=None。計算中位數的軸。如果為 None,則計算扁平化陣列的中位數。

  • keepdims (bool) – 布林值,預設值=False。如果為 true,則縮減的軸會保留在結果中,大小為 1。

  • out (None) – JAX 未使用。

  • overwrite_input (bool) – JAX 未使用。

回傳:

一個陣列,包含沿著給定軸的中位數,忽略 NaN 值。如果沿著給定軸的所有元素都是 NaN,則傳回 nan

回傳類型:

Array

另請參閱

範例

預設情況下,中位數是針對扁平化陣列計算的。

>>> nan = jnp.nan
>>> x = jnp.array([[2, nan, 7, nan],
...                [nan, 5, 9, 2],
...                [6, 1, nan, 3]])
>>> jnp.nanmedian(x)
Array(4., dtype=float32)

如果 axis=1,則沿著軸 1 計算中位數。

>>> jnp.nanmedian(x, axis=1)
Array([4.5, 5. , 3. ], dtype=float32)

如果 keepdims=True,則輸出的 ndim 等於輸入的 ndim

>>> jnp.nanmedian(x, axis=1, keepdims=True)
Array([[4.5],
       [5. ],
       [3. ]], dtype=float32)