jax.numpy.transpose#
- jax.numpy.transpose(a, axes=None)[source]#
返回 N 維陣列的轉置版本。
JAX 對
numpy.transpose()
的實作,以jax.lax.transpose()
實作。- 參數::
a (ArrayLike) – 輸入陣列
axes (Sequence[int] | None | None) – 可選地指定排列,使用長度為 a.ndim 的整數序列
i
,滿足0 <= i < a.ndim
。預設為range(a.ndim)[::-1]
,即反轉所有軸的順序。
- 返回::
陣列的轉置副本。
- 返回類型::
另請參閱
jax.Array.transpose()
:透過Array
方法的等效函數。jax.Array.T
:透過Array
屬性的等效函數。jax.numpy.matrix_transpose()
:轉置陣列的最後兩個軸。這適用於處理批次 2D 矩陣。jax.numpy.swapaxes()
:交換陣列中的任意兩個軸。jax.numpy.moveaxis()
:將軸移動到陣列中的另一個位置。
注意
與
numpy.transpose()
不同,jax.numpy.transpose()
將返回副本而不是輸入陣列的視圖。然而,在 JIT 下,編譯器將最佳化消除這些副本,因此這在實踐中不會對效能產生影響。範例
對於 1D 陣列,轉置是恆等式
>>> x = jnp.array([1, 2, 3, 4]) >>> jnp.transpose(x) Array([1, 2, 3, 4], dtype=int32)
對於 2D 陣列,轉置是矩陣轉置
>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.transpose(x) Array([[1, 3], [2, 4]], dtype=int32)
對於 N 維陣列,轉置會反轉軸的順序
>>> x = jnp.zeros(shape=(3, 4, 5)) >>> jnp.transpose(x).shape (5, 4, 3)
可以指定
axes
參數來更改此預設行為>>> jnp.transpose(x, (0, 2, 1)).shape (3, 5, 4)
由於交換最後兩個軸是常見操作,因此可以透過其自身的 API
jax.numpy.matrix_transpose()
完成>>> jnp.matrix_transpose(x).shape (3, 5, 4)
為了方便起見,轉置也可以使用
jax.Array.transpose()
方法或jax.Array.T
屬性執行>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> x.transpose() Array([[1, 3], [2, 4]], dtype=int32) >>> x.T Array([[1, 3], [2, 4]], dtype=int32)