jax.scipy.stats.mode#
- jax.scipy.stats.mode(a, axis=0, nan_policy='propagate', keepdims=False)[原始碼]#
計算陣列沿軸的眾數(最常見的值)。
JAX 實作的
scipy.stats.mode()
。- 參數:
- 返回:
陣列的元組,
(mode, count)
。mode
是眾數值的陣列,而count
是每個值在輸入陣列中出現的次數。- 返回類型:
ModeResult
範例
>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3]) >>> mode, count = jax.scipy.stats.mode(x) >>> mode, count (Array(4, dtype=int32), Array(3, dtype=int32))
對於多維陣列,
jax.scipy.stats.mode
計算沿axis=0
的mode
和對應的count
>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1], ... [3, 1, 3, 2, 1, 3], ... [1, 2, 2, 3, 1, 2]]) >>> mode, count = jax.scipy.stats.mode(x1) >>> mode, count (Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))
如果
axis=1
,則將沿axis 1
計算mode
和count
。>>> mode, count = jax.scipy.stats.mode(x1, axis=1) >>> mode, count (Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))
預設情況下,
jax.scipy.stats.mode
會縮減結果的維度。若要使維度與輸入陣列的維度相同,則必須將引數keepdims
設定為True
。>>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True) >>> mode, count (Array([[1], [3], [2]], dtype=int32), Array([[3], [3], [3]], dtype=int32))