jax.numpy.linalg.matrix_transpose#

jax.numpy.linalg.matrix_transpose(x, /)[原始碼]#

轉置矩陣或矩陣堆疊。

JAX 實作的 numpy.linalg.matrix_transpose()

參數:

x (ArrayLike) – 形狀為 (..., M, N) 的陣列

傳回:

形狀為 (..., N, M) 的陣列,包含 x 的矩陣轉置。

傳回類型:

Array

參見

jax.numpy.transpose():更通用的轉置運算。

範例

單一矩陣的轉置

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.linalg.matrix_transpose(x)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

矩陣堆疊的轉置

>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.linalg.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)

為了方便起見,相同的計算可以透過 JAX 陣列物件的 mT 屬性完成

>>> x.mT
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)