jax.flatten_util.ravel_pytree#
- jax.flatten_util.ravel_pytree(pytree)[原始碼]#
將陣列的 pytree 展平(ravel)成一維陣列。
- 參數:
pytree – 要展平的陣列和純量 pytree。
- 傳回:
一個配對,其中第一個元素是一維陣列,表示展平且串聯的葉節點值,其 dtype 由提升葉節點值的 dtype 決定,第二個元素是一個可調用物件,用於將相同長度的一維向量反展平回與輸入
pytree
結構相同的 pytree。如果輸入 pytree 為空(即沒有葉節點),則按照慣例,將在輸出的第一個組件中傳回 dtype 為 float32 的一維空陣列。
有關 dtype 提升的詳細資訊,請參閱 https://jax.dev.org.tw/en/latest/type_promotion.html。