jax.numpy.std#

jax.numpy.std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, correction=None)[原始碼]#

沿著給定軸計算標準差。

numpy.std() 的 JAX 實作。

參數:
  • a (ArrayLike) – 輸入陣列。

  • axis (Axis | None) – 選用,整數或整數序列,預設值=None。計算標準差的軸。如果為 None,則沿所有軸計算標準差。

  • dtype (DTypeLike | None | None) – 輸出陣列的類型。預設值=None。

  • ddof (int) – 整數,預設值=0。自由度。標準差計算中的除數為 N-ddofN 是沿給定軸的元素數。

  • keepdims (bool) – 布林值,預設值=False。如果為 true,則縮減的軸會保留在結果中,大小為 1。

  • where (ArrayLike | None | None) – 選用,布林陣列,預設值=None。用於標準差的元素。陣列應與輸入廣播相容。

  • correction (int | float | None | None) – 整數或浮點數,預設值=None。ddof 的替代名稱。ddof 和 correction 不能同時提供。

  • out (None | None) – JAX 未使用。

傳回:

沿著給定軸的標準差陣列。

傳回類型:

Array

另請參閱

範例

預設情況下,jnp.std 會計算沿所有軸的標準差。

>>> x = jnp.array([[1, 3, 4, 2],
...                [4, 2, 5, 3],
...                [5, 4, 2, 3]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.std(x)
Array(1.21, dtype=float32)

如果 axis=0,則沿軸 0 計算。

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0))
[1.7  0.82 1.25 0.47]

若要保留輸入的維度,您可以設定 keepdims=True

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0, keepdims=True))
[[1.7  0.82 1.25 0.47]]

如果 ddof=1

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0, keepdims=True, ddof=1))
[[2.08 1.   1.53 0.58]]

若要包含陣列的特定元素來計算標準差,您可以使用 where

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 0, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.std(x, axis=0, keepdims=True, where=where)
Array([[2., 1., 1., 0.]], dtype=float32)