jax.tree.flatten#

jax.tree.flatten(tree, is_leaf=None)[原始碼]#

展平 pytree。

展平順序(即輸出列表中元素的順序)是確定的,對應於從左到右的深度優先樹遍歷。

參數:
  • tree (Any) – 要展平的 pytree。

  • is_leaf (Callable[[Any], bool] | None | None) – 一個可選指定的函數,將在每個展平步驟中調用。它應該返回一個布林值,true 表示停止遍歷且整個子樹被視為葉節點,false 表示展平應該遍歷當前物件。

返回:

一個 pair,其中第一個元素是葉值的列表,第二個元素是表示展平樹結構的 treedef。

返回型別:

tuple[list[tree_util.Leaf], tree_util.PyTreeDef]

範例

>>> import jax
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
>>> vals
[1, 2, 3, 4, 5]
>>> treedef
PyTreeDef([*, (*, *), [*, *]])