jax.numpy.linalg.norm#

jax.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)[原始碼]#

計算矩陣或向量的範數。

JAX 實作的 numpy.linalg.norm()

參數:
  • x (ArrayLike) – 將計算範數的 N 維陣列。

  • ord (int | str | None) – 指定要採用的範數種類。預設值為矩陣的 Frobenius 範數,以及向量的 2-範數。如需其他選項,請參閱下方的「註解」。

  • axis (None | tuple[int, ...] | int) – 整數或整數序列,指定將計算範數的軸。預設值為 x 的所有軸。

  • keepdims (bool) – 若為 True,輸出陣列的維度數量會與輸入相同,而縮減軸的大小會替換為 1 (預設值:False)。

傳回值:

包含 x 指定範數的陣列。

傳回型別:

Array

筆記

計算出的範數類型取決於 ord 的值和縮減軸的數量。

對於向量範數 (即單軸縮減)

  • ord=None (預設值) 計算 2-範數

  • ord=inf 計算 max(abs(x))

  • ord=-inf 計算 min(abs(x))``

  • ord=0 計算 sum(x!=0)

  • 對於其他數值,計算 sum(abs(x) ** ord)**(1/ord)

對於矩陣範數 (即雙軸縮減)

  • ord='fro'ord=None (預設值) 計算 Frobenius 範數

  • ord='nuc' 計算核範數,或奇異值的總和

  • ord=1 計算 max(abs(x).sum(0))

  • ord=-1 計算 min(abs(x).sum(0))

  • ord=2 計算 2-範數,即最大的奇異值

  • ord=-2 計算最小的奇異值

範例

向量範數

>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)

矩陣範數

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)

批次向量範數

>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)