jax.numpy.nanmedian#
- jax.numpy.nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False)[原始碼]#
傳回沿著給定軸的陣列元素中位數,忽略 NaN 值。
JAX 版本的
numpy.nanmedian()
。- 參數:
- 回傳:
一個陣列,包含沿著給定軸的中位數,忽略 NaN 值。如果沿著給定軸的所有元素都是 NaN,則傳回
nan
。- 回傳類型:
另請參閱
jax.numpy.nanmean()
:計算沿著給定軸的陣列元素平均值,忽略 NaN 值。jax.numpy.nanmax()
:計算沿著給定軸的陣列元素最大值,忽略 NaN 值。jax.numpy.nanmin()
:計算沿著給定軸的陣列元素最小值,忽略 NaN 值。
範例
預設情況下,中位數是針對扁平化陣列計算的。
>>> nan = jnp.nan >>> x = jnp.array([[2, nan, 7, nan], ... [nan, 5, 9, 2], ... [6, 1, nan, 3]]) >>> jnp.nanmedian(x) Array(4., dtype=float32)
如果
axis=1
,則沿著軸 1 計算中位數。>>> jnp.nanmedian(x, axis=1) Array([4.5, 5. , 3. ], dtype=float32)
如果
keepdims=True
,則輸出的ndim
等於輸入的ndim
。>>> jnp.nanmedian(x, axis=1, keepdims=True) Array([[4.5], [5. ], [3. ]], dtype=float32)