Pytree#
什麼是 pytree?#
在 JAX 中,我們使用術語 *pytree* 來指稱由類似容器的 Python 物件建構而成的樹狀結構。如果類別在 pytree 登錄檔中,則會被視為類似容器,預設情況下,這包括列表、元組和字典。也就是說
類型 *不在* pytree 容器登錄檔中的任何物件都被視為 *葉節點* pytree;
類型在 pytree 容器登錄檔中,且包含 pytree 的任何物件都被視為 pytree。
對於 pytree 容器登錄檔中的每個條目,都會註冊類似容器的類型,並配對函數,這些函數指定如何將容器類型的實例轉換為 (children, metadata)
配對,以及如何將此類配對轉換回容器類型的實例。使用這些函數,JAX 可以將任何已註冊容器物件的樹狀結構標準化為元組。
Pytree 範例
[1, "a", object()] # 3 leaves
(1, (2, 3), ()) # 3 leaves
[1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves
JAX 可以擴充以將其他容器類型視為 pytree;請參閱下方的 擴充 pytree。
Pytree 和 JAX 函數#
許多 JAX 函數,例如 jax.lax.scan()
,都在陣列的 pytree 上運作。JAX 函數轉換可以應用於接受陣列的 pytree 作為輸入並產生陣列的 pytree 作為輸出的函數。
將可選參數應用於 pytree#
某些 JAX 函數轉換會採用可選參數,這些參數指定應如何處理某些輸入或輸出值(例如,vmap()
的 in_axes
和 out_axes
參數)。這些參數也可以是 pytree,並且它們的結構必須對應於相應引數的 pytree 結構。特別是,為了能夠將這些參數 pytree 中的葉節點與引數 pytree 中的值「匹配」,參數 pytree 通常被限制為引數 pytree 的樹狀前綴。
例如,如果我們將以下輸入傳遞給 vmap()
(請注意,函數的輸入引數被視為元組)
(a1, {"k1": a2, "k2": a3})
我們可以使用以下 in_axes
pytree 來指定僅映射 k2
引數 (axis=0
),而其餘引數則不映射 (axis=None
)
(None, {"k1": None, "k2": 0})
可選參數 pytree 結構必須與主輸入 pytree 的結構相符。但是,可選參數可以選擇性地指定為「前綴」pytree,這表示單個葉節點值可以應用於整個子 pytree。例如,如果我們與上述的 vmap()
輸入相同,但希望僅映射字典引數,則可以使用
(None, 0) # equivalent to (None, {"k1": 0, "k2": 0})
或者,如果我們希望映射每個引數,我們可以簡單地編寫一個應用於整個引數元組 pytree 的單個葉節點值
0
這碰巧是 vmap()
的預設 in_axes
值!
相同的邏輯適用於引用已轉換函數的特定輸入或輸出值的其他可選參數,例如 vmap
的 out_axes
。
檢視物件的 pytree 定義#
若要檢視任意 object
的 pytree 定義以進行除錯,您可以使用
from jax.tree_util import tree_structure
print(tree_structure(object))
開發者資訊#
這主要是 JAX 內部文件,終端使用者不應該需要理解這些才能使用 JAX,除非向 JAX 註冊新的使用者定義容器類型。其中一些細節可能會變更。
內部 pytree 處理#
JAX 在 api.py
邊界(以及在控制流程基本運算中)將 pytree 平坦化為葉節點列表。這使下游 JAX 內部結構更簡單:例如 grad()
、jit()
和 vmap()
等轉換可以處理接受和傳回各種不同 Python 容器的使用者函數,而系統的所有其他部分都可以在僅採用(多個)陣列引數並始終傳回平坦陣列列表的函數上運作。
當 JAX 平坦化 pytree 時,它會產生葉節點列表和 treedef
物件,該物件編碼原始值的結構。treedef
隨後可用於在轉換葉節點後建構相符的結構化值。Pytree 是樹狀結構,而不是 DAG 狀或圖狀結構,因為我們處理它們時假設參考透明性,並且它們不能包含參考週期。
以下是一個簡單範例
from jax.tree_util import tree_flatten, tree_unflatten
import jax.numpy as jnp
# The structured value to be transformed
value_structured = [1., (2., 3.)]
# The leaves in value_flat correspond to the `*` markers in value_tree
value_flat, value_tree = tree_flatten(value_structured)
print(f"{value_flat=}\n{value_tree=}")
# Transform the flat value list using an element-wise numeric transformer
transformed_flat = list(map(lambda v: v * 2., value_flat))
print(f"{transformed_flat=}")
# Reconstruct the structured output, using the original
transformed_structured = tree_unflatten(value_tree, transformed_flat)
print(f"{transformed_structured=}")
value_flat=[1.0, 2.0, 3.0]
value_tree=PyTreeDef([*, (*, *)])
transformed_flat=[2.0, 4.0, 6.0]
transformed_structured=[2.0, (4.0, 6.0)]
預設情況下,pytree 容器可以是列表、元組、字典、namedtuple、None、OrderedDict。其他類型的值,包括數值和 ndarray 值,都被視為葉節點
from collections import namedtuple
Point = namedtuple('Point', ['x', 'y'])
example_containers = [
(1., [2., 3.]),
(1., {'b': 2., 'a': 3.}),
1.,
None,
jnp.zeros(2),
Point(1., 2.)
]
def show_example(structured):
flat, tree = tree_flatten(structured)
unflattened = tree_unflatten(tree, flat)
print(f"{structured=}\n {flat=}\n {tree=}\n {unflattened=}")
for structured in example_containers:
show_example(structured)
structured=(1.0, [2.0, 3.0])
flat=[1.0, 2.0, 3.0]
tree=PyTreeDef((*, [*, *]))
unflattened=(1.0, [2.0, 3.0])
structured=(1.0, {'b': 2.0, 'a': 3.0})
flat=[1.0, 3.0, 2.0]
tree=PyTreeDef((*, {'a': *, 'b': *}))
unflattened=(1.0, {'a': 3.0, 'b': 2.0})
structured=1.0
flat=[1.0]
tree=PyTreeDef(*)
unflattened=1.0
structured=None
flat=[]
tree=PyTreeDef(None)
unflattened=None
structured=Array([0., 0.], dtype=float32)
flat=[Array([0., 0.], dtype=float32)]
tree=PyTreeDef(*)
unflattened=Array([0., 0.], dtype=float32)
structured=Point(x=1.0, y=2.0)
flat=[1.0, 2.0]
tree=PyTreeDef(CustomNode(namedtuple[Point], [*, *]))
unflattened=Point(x=1.0, y=2.0)
擴充 pytree#
預設情況下,結構化值中未被識別為內部 pytree 節點(即類似容器)的任何部分都被視為葉節點
class Special(object):
def __init__(self, x, y):
self.x = x
self.y = y
def __repr__(self):
return "Special(x={}, y={})".format(self.x, self.y)
show_example(Special(1., 2.))
structured=Special(x=1.0, y=2.0)
flat=[Special(x=1.0, y=2.0)]
tree=PyTreeDef(*)
unflattened=Special(x=1.0, y=2.0)
被視為內部 pytree 節點的 Python 類型集合是可擴充的,透過類型的全域登錄檔,並且將遞迴地遍歷已註冊類型的值。若要註冊新類型,您可以使用 register_pytree_node()
from jax.tree_util import register_pytree_node
class RegisteredSpecial(Special):
def __repr__(self):
return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
def special_flatten(v):
"""Specifies a flattening recipe.
Params:
v: the value of registered type to flatten.
Returns:
a pair of an iterable with the children to be flattened recursively,
and some opaque auxiliary data to pass back to the unflattening recipe.
The auxiliary data is stored in the treedef for use during unflattening.
The auxiliary data could be used, e.g., for dictionary keys.
"""
children = (v.x, v.y)
aux_data = None
return (children, aux_data)
def special_unflatten(aux_data, children):
"""Specifies an unflattening recipe.
Params:
aux_data: the opaque data that was specified during flattening of the
current treedef.
children: the unflattened children
Returns:
a re-constructed object of the registered type, using the specified
children and auxiliary data.
"""
return RegisteredSpecial(*children)
# Global registration
register_pytree_node(
RegisteredSpecial,
special_flatten, # tell JAX what are the children nodes
special_unflatten # tell JAX how to pack back into a RegisteredSpecial
)
show_example(RegisteredSpecial(1., 2.))
structured=RegisteredSpecial(x=1.0, y=2.0)
flat=[1.0, 2.0]
tree=PyTreeDef(CustomNode(RegisteredSpecial[None], [*, *]))
unflattened=RegisteredSpecial(x=1.0, y=2.0)
或者,您可以在您的類別上定義適當的 tree_flatten
和 tree_unflatten
方法,並使用 register_pytree_node_class()
裝飾它
from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
class RegisteredSpecial2(Special):
def __repr__(self):
return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y)
def tree_flatten(self):
children = (self.x, self.y)
aux_data = None
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
show_example(RegisteredSpecial2(1., 2.))
structured=RegisteredSpecial2(x=1.0, y=2.0)
flat=[1.0, 2.0]
tree=PyTreeDef(CustomNode(RegisteredSpecial2[None], [*, *]))
unflattened=RegisteredSpecial2(x=1.0, y=2.0)
在定義 unflattening 函數時,一般來說,children
應包含資料結構的所有動態元素(陣列、動態純量和 pytree),而 aux_data
應包含將滾動到 treedef
結構中的所有靜態元素。JAX 有時需要比較 treedef
是否相等,或計算其雜湊值以用於 JIT 快取,因此必須注意確保在平坦化配方中指定的輔助資料支援有意義的雜湊和相等性比較。
用於操作 pytree 的整組函數都在 jax.tree_util
中。
自訂 PyTree 和初始化#
使用者定義的 PyTree 物件的一個常見陷阱是,JAX 轉換有時會使用非預期的值初始化它們,因此在初始化時完成的任何輸入驗證都可能會失敗。例如
class MyTree:
def __init__(self, a):
self.a = jnp.asarray(a)
register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
lambda _, args: MyTree(*args))
tree = MyTree(jnp.arange(5.0))
jax.vmap(lambda x: x)(tree) # Error because object() is passed to MyTree.
jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to MyTree
在第一種情況下,JAX 的內部結構使用 object()
值的陣列來推斷樹狀結構;在第二種情況下,將樹狀結構映射到樹狀結構的函數的雅可比矩陣定義為樹狀結構的樹狀結構。
因此,自訂 PyTree 類別的 __init__
和 __new__
方法通常應避免執行任何陣列轉換或其他輸入驗證,否則應預期並處理這些特殊情況。例如
class MyTree:
def __init__(self, a):
if not (type(a) is object or a is None or isinstance(a, MyTree)):
a = jnp.asarray(a)
self.a = a
另一種可能性是建構您的 tree_unflatten
函數,使其避免呼叫 __init__
;例如
def tree_unflatten(aux_data, children):
del aux_data # unused in this class
obj = object.__new__(MyTree)
obj.a = a
return obj
如果您採用這種方法,請確保您的 tree_unflatten
函數在程式碼更新時與 __init__
保持同步。