jax.numpy.matrix_transpose#

jax.numpy.matrix_transpose(x, /)[原始碼]#

轉置陣列的最後兩個維度。

JAX 實作的 numpy.matrix_transpose(),以 jax.lax.transpose() 實作。

參數:

x (ArrayLike) – 輸入陣列,必須有 x.ndim >= 2

回傳:

陣列的矩陣轉置副本。

回傳類型:

Array

另請參閱

注意

numpy.matrix_transpose() 不同,jax.numpy.matrix_transpose() 將回傳輸入陣列的副本而不是視圖。然而,在 JIT 下,編譯器會在可能的情況下最佳化掉這些副本,因此這在實務上不會對效能產生影響。

範例

這是一個 2x2x2 矩陣,表示批次化的 2x2 矩陣

>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)

為了方便起見,您可以透過 mT 屬性執行相同的轉置 jax.Array

>>> x.mT
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)