jax.lax.collapse# jax.lax.collapse(operand, start_dimension, stop_dimension=None)[原始碼]# 將陣列的維度摺疊成單一維度。 例如,如果 operand 是一個形狀為 [2, 3, 4] 的陣列,則 collapse(operand, 0, 2).shape == [6, 4]。摺疊維度的元素以 major-to-minor 方式佈局,即編號最低的維度是變化最慢的維度。 參數: operand (Array) – 輸入陣列。 start_dimension (int) – 要摺疊的維度的起始位置(包含)。 stop_dimension (int | None | None) – 要摺疊的維度的結束位置(不包含)。傳遞 None 以摺疊起始位置之後的所有維度。 傳回: 維度 [start_dimension, stop_dimension) 已摺疊(展平)成單一維度的陣列。 傳回類型: Array