jax.numpy.reshape#
- jax.numpy.reshape(a, shape=None, order='C', *, newshape=Deprecated, copy=None)[原始碼]#
傳回陣列的重塑副本。
JAX 版本的
numpy.reshape()
,以jax.lax.reshape()
實作。- 參數:
a (ArrayLike) – 要重塑的輸入陣列
shape (DimSize | Shape | None | None) – 整數或整數序列,指定新形狀,必須與輸入陣列的大小相符。如果任何單一維度的大小為
-1
,它將被替換為一個值,使輸出具有正確的大小。order (str) –
'F'
或'C'
,指定重塑應套用欄主序(Fortran 風格,"F"
)還是列主序(C 風格,"C"
);預設為"C"
。JAX 不支援order="A"
。copy (bool | None | None) – JAX 未使用;JAX 總是傳回副本,但在 JIT 下,編譯器可能會最佳化掉這些副本。
newshape (DimSize | Shape | DeprecatedArg) –
shape
引數的已棄用別名。如果使用,將導致DeprecationWarning
。
- 傳回:
具有指定形狀的輸入陣列的重塑副本。
- 傳回類型:
注意事項
與
numpy.reshape()
不同,jax.numpy.reshape()
將傳回輸入陣列的副本,而不是視圖。但是,在 JIT 下,編譯器會在可能的情況下最佳化掉這些副本,因此這在實務上不會產生效能影響。另請參閱
jax.Array.reshape()
:透過陣列方法的等效功能。jax.numpy.ravel()
:將陣列展平為 1D 形狀。jax.numpy.squeeze()
:從陣列的形狀中移除一個或多個長度為 1 的軸。
範例
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.reshape(x, 6) Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (3, 2)) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
您可以使用
-1
自動計算與輸入大小一致的形狀>>> jnp.reshape(x, -1) # -1 is inferred to be 6 Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (-1, 2)) # -1 is inferred to be 3 Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
reshape 中軸的預設順序是 C 風格的列主序。若要使用 Fortran 風格的欄主序,請指定
order='F'
>>> jnp.reshape(x, 6, order='F') Array([1, 4, 2, 5, 3, 6], dtype=int32) >>> jnp.reshape(x, (3, 2), order='F') Array([[1, 5], [4, 3], [2, 6]], dtype=int32)
為了方便起見,此功能也可透過
jax.Array.reshape()
方法取得>>> x.reshape(3, 2) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)