jax.numpy.matrix_transpose#
- jax.numpy.matrix_transpose(x, /)[原始碼]#
轉置陣列的最後兩個維度。
JAX 實作的
numpy.matrix_transpose()
,以jax.lax.transpose()
實作。- 參數:
x (ArrayLike) – 輸入陣列,必須有
x.ndim >= 2
- 回傳:
陣列的矩陣轉置副本。
- 回傳類型:
另請參閱
jax.Array.mT
:透過Array()
屬性存取的相同操作。jax.numpy.transpose()
:一般多軸轉置
注意
與
numpy.matrix_transpose()
不同,jax.numpy.matrix_transpose()
將回傳輸入陣列的副本而不是視圖。然而,在 JIT 下,編譯器會在可能的情況下最佳化掉這些副本,因此這在實務上不會對效能產生影響。範例
這是一個 2x2x2 矩陣,表示批次化的 2x2 矩陣
>>> x = jnp.array([[[1, 2], ... [3, 4]], ... [[5, 6], ... [7, 8]]]) >>> jnp.matrix_transpose(x) Array([[[1, 3], [2, 4]], [[5, 7], [6, 8]]], dtype=int32)
為了方便起見,您可以透過
mT
屬性執行相同的轉置jax.Array
>>> x.mT Array([[[1, 3], [2, 4]], [[5, 7], [6, 8]]], dtype=int32)