jax.scipy.stats.mode#

jax.scipy.stats.mode(a, axis=0, nan_policy='propagate', keepdims=False)[原始碼]#

計算陣列沿軸的眾數(最常見的值)。

JAX 實作的 scipy.stats.mode()

參數:
  • a (ArrayLike) – 類陣列

  • axis (int | None) – int,預設值=0。計算眾數的軸。

  • nan_policy (str) – str。JAX 僅支援 "propagate"

  • keepdims (bool) – bool,預設值=False。如果為 true,則縮減的軸會保留在結果中,大小為 1。

返回:

陣列的元組,(mode, count)mode 是眾數值的陣列,而 count 是每個值在輸入陣列中出現的次數。

返回類型:

ModeResult

範例

>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
>>> mode, count = jax.scipy.stats.mode(x)
>>> mode, count
(Array(4, dtype=int32), Array(3, dtype=int32))

對於多維陣列,jax.scipy.stats.mode 計算沿 axis=0mode 和對應的 count

>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
...                 [3, 1, 3, 2, 1, 3],
...                 [1, 2, 2, 3, 1, 2]])
>>> mode, count = jax.scipy.stats.mode(x1)
>>> mode, count
(Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))

如果 axis=1,則將沿 axis 1 計算 modecount

>>> mode, count = jax.scipy.stats.mode(x1, axis=1)
>>> mode, count
(Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))

預設情況下,jax.scipy.stats.mode 會縮減結果的維度。若要使維度與輸入陣列的維度相同,則必須將引數 keepdims 設定為 True

>>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True)
>>> mode, count
(Array([[1],
       [3],
       [2]], dtype=int32), Array([[3],
       [3],
       [3]], dtype=int32))