jax.scipy.stats.sem#

jax.scipy.stats.sem(a, axis=0, ddof=1, nan_policy='propagate', *, keepdims=False)[原始碼]#

計算平均數的標準誤。

JAX 實作的 scipy.stats.sem()

參數:
  • a (ArrayLike) – 類陣列

  • axis (int | None) – 選用整數。若未指定,則輸入陣列會被展平。

  • ddof (int) – 整數,預設值=1。SEM 計算中的自由度。

  • nan_policy (str) – 字串,預設值=”propagate”。JAX 僅支援 “propagate” 和 “omit”。

  • keepdims (bool) – 布林值,預設值=False。若為 true,則縮減的軸會保留在結果中,大小為 1。

傳回:

陣列

傳回型別:

Array

範例

>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x)
Array(0.41, dtype=float32)

對於多維陣列,sem 會沿著 axis=0 計算平均數的標準誤

>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
...                 [3, 1, 3, 2, 1, 3],
...                 [1, 2, 2, 3, 1, 2]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x1)
Array([0.67, 0.33, 0.58, 0.33, 0.33, 0.58], dtype=float32)

axis=1,平均數的標準誤將沿著 axis 1 計算。

>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x1, axis=1)
Array([0.33, 0.4 , 0.31], dtype=float32)

axis=None,平均數的標準誤將沿著所有軸計算。

>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x1, axis=None)
Array(0.2, dtype=float32)

預設情況下,sem 會縮減結果的維度。若要保持維度與輸入陣列相同,則必須將引數 keepdims 設定為 True

>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x1, axis=1, keepdims=True)
Array([[0.33],
       [0.4 ],
       [0.31]], dtype=float32)

由於預設情況下,nan_policy='propagate'sem 會將 nan 值傳播到結果中。

>>> nan = jnp.nan
>>> x2 = jnp.array([[1, 2, 3, nan, 4, 2],
...                 [4, 5, 4, 3, nan, 1],
...                 [7, nan, 8, 7, 9, nan]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x2)
Array([1.73,  nan, 1.53,  nan,  nan,  nan], dtype=float32)

nan_policy='omit`sem 會省略 nan 值,並計算指定軸上剩餘值的誤差。

>>> with jnp.printoptions(precision=2, suppress=True):
...   jax.scipy.stats.sem(x2, nan_policy='omit')
Array([1.73, 1.5 , 1.53, 2.  , 2.5 , 0.5 ], dtype=float32)