jax.numpy.linalg.matrix_norm#

jax.numpy.linalg.matrix_norm(x, /, *, keepdims=False, ord='fro')[原始碼]#

計算矩陣或矩陣堆疊的範數。

JAX 實作的 numpy.linalg.matrix_norm()

參數:
  • x (ArrayLike) – 形狀為 (..., M, N) 的陣列,用於計算範數。

  • keepdims (bool) – 若為 True,則在輸出中保留縮減的維度。

  • ord (str | int) – 字串或整數,指定範數的類型;預設值為 Frobenius 範數。如需可用選項的詳細資訊,請參閱 numpy.linalg.norm()

傳回:

包含 x 範數的陣列。若 keepdims 為 False,則形狀為 x.shape[:-2];若 keepdims 為 True,則形狀為 (..., 1, 1)

傳回類型:

Array

另請參閱

範例

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.linalg.matrix_norm(x)
Array(16.881943, dtype=float32)