jax.numpy.linalg.matrix_transpose#
- jax.numpy.linalg.matrix_transpose(x, /)[原始碼]#
轉置矩陣或矩陣堆疊。
JAX 實作的
numpy.linalg.matrix_transpose()
。- 參數:
x (ArrayLike) – 形狀為
(..., M, N)
的陣列- 傳回:
形狀為
(..., N, M)
的陣列,包含x
的矩陣轉置。- 傳回類型:
參見
jax.numpy.transpose()
:更通用的轉置運算。範例
單一矩陣的轉置
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.linalg.matrix_transpose(x) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)
矩陣堆疊的轉置
>>> x = jnp.array([[[1, 2], ... [3, 4]], ... [[5, 6], ... [7, 8]]]) >>> jnp.linalg.matrix_transpose(x) Array([[[1, 3], [2, 4]], [[5, 7], [6, 8]]], dtype=int32)
為了方便起見,相同的計算可以透過 JAX 陣列物件的
mT
屬性完成>>> x.mT Array([[[1, 3], [2, 4]], [[5, 7], [6, 8]]], dtype=int32)