jax.numpy.transpose#

jax.numpy.transpose(a, axes=None)[source]#

返回 N 維陣列的轉置版本。

JAX 對 numpy.transpose() 的實作,以 jax.lax.transpose() 實作。

參數::
  • a (ArrayLike) – 輸入陣列

  • axes (Sequence[int] | None | None) – 可選地指定排列,使用長度為 a.ndim 的整數序列 i,滿足 0 <= i < a.ndim。預設為 range(a.ndim)[::-1],即反轉所有軸的順序。

返回::

陣列的轉置副本。

返回類型::

Array

另請參閱

注意

numpy.transpose() 不同,jax.numpy.transpose() 將返回副本而不是輸入陣列的視圖。然而,在 JIT 下,編譯器將最佳化消除這些副本,因此這在實踐中不會對效能產生影響。

範例

對於 1D 陣列,轉置是恆等式

>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.transpose(x)
Array([1, 2, 3, 4], dtype=int32)

對於 2D 陣列,轉置是矩陣轉置

>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.transpose(x)
Array([[1, 3],
       [2, 4]], dtype=int32)

對於 N 維陣列,轉置會反轉軸的順序

>>> x = jnp.zeros(shape=(3, 4, 5))
>>> jnp.transpose(x).shape
(5, 4, 3)

可以指定 axes 參數來更改此預設行為

>>> jnp.transpose(x, (0, 2, 1)).shape
(3, 5, 4)

由於交換最後兩個軸是常見操作,因此可以透過其自身的 API jax.numpy.matrix_transpose() 完成

>>> jnp.matrix_transpose(x).shape
(3, 5, 4)

為了方便起見,轉置也可以使用 jax.Array.transpose() 方法或 jax.Array.T 屬性執行

>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> x.transpose()
Array([[1, 3],
       [2, 4]], dtype=int32)
>>> x.T
Array([[1, 3],
       [2, 4]], dtype=int32)