jax.lax.reshape#

jax.lax.reshape(operand, new_sizes, dimensions=None, sharding=None)[原始碼]#

包裝 XLA 的 Reshape 運算子。

對於插入/移除大小為 1 的維度,建議優先使用 lax.squeeze / lax.expand_dims。這些保留了關於軸身分識別的資訊,可能對進階轉換規則很有用。

參數:
  • operand (ArrayLike) – 要 reshape 的陣列。

  • new_sizes (Shape) – 指定結果形狀的整數序列。最終陣列的大小必須與輸入的大小相符。

  • dimensions (Sequence[int] | None | None) – 指定輸入形狀的排列順序的可選整數序列。如果指定,長度必須與 operand.shape 相符。

  • sharding (NamedSharding | P | None | None)

返回:

reshape 後的陣列。

返回型別:

out

範例

從一維到二維的簡單 reshape

>>> x = jnp.arange(6)
>>> y = reshape(x, (2, 3))
>>> y
Array([[0, 1, 2],
             [3, 4, 5]], dtype=int32)

Reshape 回一維

>>> reshape(y, (6,))
Array([0, 1, 2, 3, 4, 5], dtype=int32)

使用維度排列 reshape 為一維

>>> reshape(y, (6,), (1, 0))
Array([0, 3, 1, 4, 2, 5], dtype=int32)