jax.numpy.reshape#

jax.numpy.reshape(a, shape=None, order='C', *, newshape=Deprecated, copy=None)[原始碼]#

傳回陣列的重塑副本。

JAX 版本的 numpy.reshape(),以 jax.lax.reshape() 實作。

參數:
  • a (ArrayLike) – 要重塑的輸入陣列

  • shape (DimSize | Shape | None | None) – 整數或整數序列,指定新形狀,必須與輸入陣列的大小相符。如果任何單一維度的大小為 -1,它將被替換為一個值,使輸出具有正確的大小。

  • order (str) – 'F''C',指定重塑應套用欄主序(Fortran 風格,"F")還是列主序(C 風格,"C");預設為 "C"。JAX 不支援 order="A"

  • copy (bool | None | None) – JAX 未使用;JAX 總是傳回副本,但在 JIT 下,編譯器可能會最佳化掉這些副本。

  • newshape (DimSize | Shape | DeprecatedArg) – shape 引數的已棄用別名。如果使用,將導致 DeprecationWarning

傳回:

具有指定形狀的輸入陣列的重塑副本。

傳回類型:

Array

注意事項

numpy.reshape() 不同,jax.numpy.reshape() 將傳回輸入陣列的副本,而不是視圖。但是,在 JIT 下,編譯器會在可能的情況下最佳化掉這些副本,因此這在實務上不會產生效能影響。

另請參閱

範例

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.reshape(x, 6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2))
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

您可以使用 -1 自動計算與輸入大小一致的形狀

>>> jnp.reshape(x, -1)  # -1 is inferred to be 6
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (-1, 2))  # -1 is inferred to be 3
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

reshape 中軸的預設順序是 C 風格的列主序。若要使用 Fortran 風格的欄主序,請指定 order='F'

>>> jnp.reshape(x, 6, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2), order='F')
Array([[1, 5],
       [4, 3],
       [2, 6]], dtype=int32)

為了方便起見,此功能也可透過 jax.Array.reshape() 方法取得

>>> x.reshape(3, 2)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)