jax.lax.collapse#

jax.lax.collapse(operand, start_dimension, stop_dimension=None)[原始碼]#

將陣列的維度摺疊成單一維度。

例如,如果 operand 是一個形狀為 [2, 3, 4] 的陣列,則 collapse(operand, 0, 2).shape == [6, 4]。摺疊維度的元素以 major-to-minor 方式佈局,即編號最低的維度是變化最慢的維度。

參數:
  • operand (Array) – 輸入陣列。

  • start_dimension (int) – 要摺疊的維度的起始位置(包含)。

  • stop_dimension (int | None | None) – 要摺疊的維度的結束位置(不包含)。傳遞 None 以摺疊起始位置之後的所有維度。

傳回:

維度 [start_dimension, stop_dimension) 已摺疊(展平)成單一維度的陣列。

傳回類型:

Array