jax.lax.reshape#
- jax.lax.reshape(operand, new_sizes, dimensions=None, sharding=None)[原始碼]#
包裝 XLA 的 Reshape 運算子。
對於插入/移除大小為 1 的維度,建議優先使用
lax.squeeze
/lax.expand_dims
。這些保留了關於軸身分識別的資訊,可能對進階轉換規則很有用。- 參數:
operand (ArrayLike) – 要 reshape 的陣列。
new_sizes (Shape) – 指定結果形狀的整數序列。最終陣列的大小必須與輸入的大小相符。
dimensions (Sequence[int] | None | None) – 指定輸入形狀的排列順序的可選整數序列。如果指定,長度必須與
operand.shape
相符。sharding (NamedSharding | P | None | None)
- 返回:
reshape 後的陣列。
- 返回型別:
out
範例
從一維到二維的簡單 reshape
>>> x = jnp.arange(6) >>> y = reshape(x, (2, 3)) >>> y Array([[0, 1, 2], [3, 4, 5]], dtype=int32)
Reshape 回一維
>>> reshape(y, (6,)) Array([0, 1, 2, 3, 4, 5], dtype=int32)
使用維度排列 reshape 為一維
>>> reshape(y, (6,), (1, 0)) Array([0, 3, 1, 4, 2, 5], dtype=int32)