使用 pytree#
JAX 內建支援看起來像陣列字典 (dicts)、或字典列表的列表或其他巢狀結構的物件 — 在 JAX 中,這些稱為 pytree。本節將說明如何使用它們、提供有用的程式碼範例,並指出常見的「陷阱」和模式。
什麼是 pytree?#
pytree 是一種容器狀結構,由容器狀 Python 物件 — 「葉」pytree 和/或更多 pytree 建構而成。pytree 可以包含列表、元組和字典。葉是任何非 pytree 的事物,例如陣列,但單一葉也是一個 pytree。
在機器學習 (ML) 的背景下,pytree 可以包含
模型參數
資料集條目
強化學習代理程式觀察
當使用資料集時,您經常會遇到 pytree (例如字典列表的列表)。
以下是一個簡單 pytree 的範例。在 JAX 中,您可以使用 jax.tree.leaves()
,從樹狀結構中提取扁平化的葉節點,如下所示
import jax
import jax.numpy as jnp
example_trees = [
[1, 'a', object()],
(1, (2, 3), ()),
[1, {'k1': 2, 'k2': (3, 4)}, 5],
{'a': 2, 'b': (2, 3)},
jnp.array([1, 2, 3]),
]
# Print how many leaves the pytrees have.
for pytree in example_trees:
# This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees.
leaves = jax.tree.leaves(pytree)
print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x7fa718bffe20>] has 3 leaves: [1, 'a', <object object at 0x7fa718bffe20>]
(1, (2, 3), ()) has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32) has 1 leaves: [Array([1, 2, 3], dtype=int32)]
任何由容器狀 Python 物件建構而成的樹狀結構,都可以在 JAX 中視為 pytree。如果類別位於 pytree 登錄檔中,則視為容器狀,預設情況下,這包括列表、元組和字典。任何類型未在 pytree 容器登錄檔中的物件,都將在樹狀結構中視為葉節點。
pytree 登錄檔可以擴充,以包含使用者定義的容器類別,方法是使用指定如何扁平化樹狀結構的函式來登錄類別;請參閱下方的自訂 pytree 節點。
常見的 pytree 函式#
JAX 提供了許多公用程式來對 pytree 進行操作。這些可以在 jax.tree_util
子套件中找到;為了方便起見,其中許多在 jax.tree
模組中都有別名。
常用函式:jax.tree.map
#
最常用的 pytree 函式是 jax.tree.map()
。它的運作方式類似於 Python 的原生 map
,但可以透明地對整個 pytree 進行操作。
以下範例
list_of_lists = [
[1, 2, 3],
[1, 2],
[1, 2, 3, 4]
]
jax.tree.map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
jax.tree.map()
也允許將 N 元函式映射到多個引數。例如
another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
當使用多個引數搭配 jax.tree.map()
時,輸入的結構必須完全相符。也就是說,列表必須具有相同數量的元素,字典必須具有相同的金鑰等等。
jax.tree.map
搭配 ML 模型參數的範例#
此範例示範了在訓練簡單的多層感知器 (MLP) 時,pytree 操作如何有用。
首先定義初始模型參數
import numpy as np
def init_mlp_params(layer_widths):
params = []
for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
params.append(
dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
biases=np.ones(shape=(n_out,))
)
)
return params
params = init_mlp_params([1, 128, 128, 1])
使用 jax.tree.map()
檢查初始參數的形狀
jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
{'biases': (128,), 'weights': (128, 128)},
{'biases': (1,), 'weights': (128, 1)}]
接下來,定義用於訓練 MLP 模型的函式
# Define the forward pass.
def forward(params, x):
*hidden, last = params
for layer in hidden:
x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
return x @ last['weights'] + last['biases']
# Define the loss function.
def loss_fn(params, x, y):
return jnp.mean((forward(params, x) - y) ** 2)
# Set the learning rate.
LEARNING_RATE = 0.0001
# Using the stochastic gradient descent, define the parameter update function.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
# Calculate the gradients with `jax.grad`.
grads = jax.grad(loss_fn)(params, x, y)
# Note that `grads` is a pytree with the same structure as `params`.
# `jax.grad` is one of many JAX functions that has
# built-in support for pytrees.
# This is useful - you can apply the SGD update using JAX pytree utilities.
return jax.tree.map(
lambda p, g: p - LEARNING_RATE * g, params, grads
)
自訂 pytree 節點#
本節說明如何在 JAX 中擴充將被視為 pytree (pytree 節點) 中的內部節點的 Python 類型集,方法是搭配 jax.tree.map()
使用 jax.tree_util.register_pytree_node()
。
為什麼您會需要這個?在先前的範例中,pytree 顯示為列表、元組和字典,其他所有內容都作為 pytree 葉節點。這是因為如果您定義自己的容器類別,除非您向 JAX 登錄它,否則它將被視為 pytree 葉節點。即使您的容器類別內部有樹狀結構,情況也是如此。例如
class Special(object):
def __init__(self, x, y):
self.x = x
self.y = y
jax.tree.leaves([
Special(0, 1),
Special(2, 4),
])
[<__main__.Special at 0x7fa6fc4db370>, <__main__.Special at 0x7fa6fc4dbeb0>]
因此,如果您嘗試使用 jax.tree.map()
,並期望葉節點是容器內部的元素,您將收到錯誤
jax.tree.map(lambda x: x + 1,
[
Special(0, 1),
Special(2, 4)
])
TypeError: unsupported operand type(s) for +: 'Special' and 'int'
作為解決方案,JAX 允許透過全域類型登錄檔來擴充要視為內部 pytree 節點的類型集。此外,還會遞迴式地遍歷已登錄類型的值。
首先,使用 jax.tree_util.register_pytree_node()
登錄新類型
from jax.tree_util import register_pytree_node
class RegisteredSpecial(Special):
def __repr__(self):
return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
def special_flatten(v):
"""Specifies a flattening recipe.
Params:
v: The value of the registered type to flatten.
Returns:
A pair of an iterable with the children to be flattened recursively,
and some opaque auxiliary data to pass back to the unflattening recipe.
The auxiliary data is stored in the treedef for use during unflattening.
The auxiliary data could be used, for example, for dictionary keys.
"""
children = (v.x, v.y)
aux_data = None
return (children, aux_data)
def special_unflatten(aux_data, children):
"""Specifies an unflattening recipe.
Params:
aux_data: The opaque data that was specified during flattening of the
current tree definition.
children: The unflattened children
Returns:
A reconstructed object of the registered type, using the specified
children and auxiliary data.
"""
return RegisteredSpecial(*children)
# Global registration
register_pytree_node(
RegisteredSpecial,
special_flatten, # Instruct JAX what are the children nodes.
special_unflatten # Instruct JAX how to pack back into a `RegisteredSpecial`.
)
現在您可以遍歷特殊容器結構
jax.tree.map(lambda x: x + 1,
[
RegisteredSpecial(0, 1),
RegisteredSpecial(2, 4),
])
[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]
現代 Python 配備了實用的工具,可讓定義容器更容易。有些可以直接與 JAX 搭配使用,但其他則需要更謹慎地處理。
例如,Python NamedTuple
子類別不需要登錄即可視為 pytree 節點類型
from typing import NamedTuple, Any
class MyOtherContainer(NamedTuple):
name: str
a: Any
b: Any
c: Any
# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
MyOtherContainer('Alice', 1, 2, 3),
MyOtherContainer('Bob', 4, 5, 6)
])
['Alice', 1, 2, 3, 'Bob', 4, 5, 6]
請注意,name
欄位現在顯示為葉節點,因為所有元組元素都是子節點。這是當您不必以複雜的方式登錄類別時會發生的情況。
與 NamedTuple
子類別不同,以 @dataclass
修飾的類別不會自動成為 pytree。但是,可以使用 jax.tree_util.register_dataclass()
修飾器將它們登錄為 pytree
from dataclasses import dataclass
import functools
@functools.partial(jax.tree_util.register_dataclass,
data_fields=['a', 'b', 'c'],
meta_fields=['name'])
@dataclass
class MyDataclassContainer(object):
name: str
a: Any
b: Any
c: Any
# MyDataclassContainer is now a pytree node.
jax.tree.leaves([
MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])),
MyDataclassContainer('banana', np.array([3, 4]), -1., 0.)
])
[5.3, 1.2, Array([0., 0., 0., 0.], dtype=float32), array([3, 4]), -1.0, 0.0]
請注意,name
欄位不會顯示為葉節點。這是因為我們將它包含在 jax.tree_util.register_dataclass()
的 meta_fields
引數中,表示它應該被視為中繼資料/輔助資料,就像上面 RegisteredSpecial
中的 aux_data
一樣。現在,MyDataclassContainer
的執行個體可以傳遞到 JIT 函式中,而 name
將被視為靜態 (如需靜態引數的詳細資訊,請參閱將引數標記為靜態)
@jax.jit
def f(x: MyDataclassContainer | MyOtherContainer):
return x.a + x.b
# Works fine! `mdc.name` is static.
mdc = MyDataclassContainer('mdc', 1, 2, 3)
y = f(mdc)
將此與 NamedTuple
子類別 MyOtherContainer
進行比較。由於 name
欄位是 pytree 葉節點,因此 JIT 預期它可以轉換為 jax.Array
,並且以下程式碼會引發錯誤
moc = MyOtherContainer('moc', 1, 2, 3)
y = f(moc)
TypeError: Error interpreting argument to <function f at 0x7fa6fc5097e0> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path x.name.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
Pytree 和 JAX 轉換#
許多 JAX 函式 (例如 jax.lax.scan()
) 對陣列的 pytree 進行操作。此外,所有 JAX 函式轉換都可以應用於接受陣列 pytree 作為輸入並產生陣列 pytree 作為輸出的函式。
某些 JAX 函式轉換採用選用參數,這些參數指定應如何處理某些輸入或輸出值 (例如 jax.vmap()
的 in_axes
和 out_axes
引數)。這些參數也可以是 pytree,並且其結構必須對應於對應引數的 pytree 結構。特別是,為了能夠將這些參數 pytree 中的葉節點與引數 pytree 中的值「匹配」,參數 pytree 通常會限制為引數 pytree 的樹狀結構前綴。
例如,如果您將以下輸入傳遞至 jax.vmap()
(請注意,函式的輸入引數會視為元組)
vmap(f, in_axes=(a1, {"k1": a2, "k2": a3}))
那麼您可以使用以下 in_axes
pytree 來指定僅映射 k2
引數 (axis=0
),而其餘引數則不進行映射 (axis=None
)
vmap(f, in_axes=(None, {"k1": None, "k2": 0}))
選用參數 pytree 結構必須與主要輸入 pytree 的結構相符。但是,選用參數可以選擇性地指定為「前綴」pytree,這表示單個葉節點值可以應用於整個子 pytree。
例如,如果您具有與上述相同的 jax.vmap()
輸入,但希望僅映射字典引數,則可以使用
vmap(f, in_axes=(None, 0)) # equivalent to (None, {"k1": 0, "k2": 0})
或者,如果您希望映射每個引數,則可以撰寫應用於整個引數元組 pytree 的單個葉節點值
vmap(f, in_axes=0) # equivalent to (0, {"k1": 0, "k2": 0})
這恰好是 jax.vmap()
的預設 in_axes
值。
相同的邏輯適用於其他參考轉換函式的特定輸入或輸出值的選用參數,例如 jax.vmap()
中的 out_axes
。
顯式金鑰路徑#
在 pytree 中,每個葉節點都有一個金鑰路徑。葉節點的金鑰路徑是金鑰的 list
,其中列表的長度等於葉節點在 pytree 中的深度。每個金鑰都是一個 可雜湊物件,表示對應 pytree 節點類型的索引。金鑰的類型取決於 pytree 節點類型;例如,dict
的金鑰類型與 tuple
的金鑰類型不同。
對於內建 pytree 節點類型,任何 pytree 節點執行個體的金鑰集都是唯一的。對於包含具有此屬性的節點的 pytree,每個葉節點的金鑰路徑都是唯一的。
JAX 具有以下 jax.tree_util.*
方法,可用於處理金鑰路徑
jax.tree_util.tree_flatten_with_path()
:運作方式類似於jax.tree.flatten()
,但會傳回金鑰路徑。jax.tree_util.tree_map_with_path()
:運作方式類似於jax.tree.map()
,但函式也會將金鑰路徑作為引數。jax.tree_util.keystr()
:給定一般金鑰路徑,傳回易於閱讀的字串運算式。
例如,一個使用案例是列印與特定葉節點值相關的偵錯資訊
import collections
ATuple = collections.namedtuple("ATuple", ('name'))
tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)
for key_path, value in flattened:
print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo
為了表示金鑰路徑,JAX 為內建 pytree 節點類型提供了一些預設金鑰類型,即
SequenceKey(idx: int)
:適用於列表和元組。DictKey(key: Hashable)
:適用於字典。GetAttrKey(name: str)
:適用於namedtuple
和最好是自訂 pytree 節點 (更多資訊請參閱下一節)
您可以自由地為您的自訂節點定義自己的金鑰類型。它們將與 jax.tree_util.keystr()
搭配使用,只要它們的 __str__()
方法也使用易於閱讀的運算式覆寫即可。
for key_path, _ in flattened:
print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))
常見的 pytree 陷阱#
本節涵蓋使用 JAX pytree 時遇到的一些最常見問題 («陷阱»)。
將 pytree 節點誤認為葉節點#
需要注意的一個常見陷阱是意外引入樹狀結構節點而不是葉節點
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]
# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)
[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
(Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]
這裡發生的情況是陣列的 shape
是一個元組,它是 pytree 節點,其元素作為葉節點。因此,在映射中,不是在例如 (2, 3)
上呼叫 jnp.ones
,而是在 2
和 3
上呼叫。
解決方案將取決於具體情況,但有兩個廣泛適用的選項
重新撰寫程式碼以避免中繼
jax.tree.map()
。將元組轉換為 NumPy 陣列 (
np.array
) 或 JAX NumPy 陣列 (jnp.array
),這會使整個序列成為葉節點。
jax.tree_util
對 None
的處理#
jax.tree_util
函式將 None
視為缺少 pytree 節點,而不是葉節點
jax.tree.leaves([None, None, None])
[]
若要將 None
視為葉節點,您可以使用 is_leaf
引數
jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
[None, None, None]
自訂 pytree 和使用非預期值初始化#
使用者定義的 pytree 物件的另一個常見陷阱是,JAX 轉換偶爾會使用非預期值初始化它們,因此在初始化時完成的任何輸入驗證都可能會失敗。例如
class MyTree:
def __init__(self, a):
self.a = jnp.asarray(a)
register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
lambda _, args: MyTree(*args))
tree = MyTree(jnp.arange(5.0))
jax.vmap(lambda x: x)(tree) # Error because object() is passed to `MyTree`.
TypeError: Value '<object object at 0x7fa6fc58c710>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to `MyTree`.
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:5811: FutureWarning: None encountered in jnp.array(); this is currently treated as NaN. In the future this will result in an error.
return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
TypeError: Value '<object object at 0x7fa6fc58cc00>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
在第一個使用
jax.vmap(...)(tree)
的情況下,JAX 的內部機制會使用object()
值的陣列來推斷樹狀結構的結構在第二個使用
jax.jacobian(...)(tree)
的情況下,將樹狀結構映射到樹狀結構的函式的 Jacobian 定義為樹狀結構的樹狀結構。
潛在解決方案 1
自訂 pytree 類別的
__init__
和__new__
方法通常應避免執行任何陣列轉換或其他輸入驗證,否則請預期並處理這些特殊情況。例如
class MyTree:
def __init__(self, a):
if not (type(a) is object or a is None or isinstance(a, MyTree)):
a = jnp.asarray(a)
self.a = a
潛在解決方案 2
設計您的自訂
tree_unflatten
函數,使其避免呼叫__init__
。如果您選擇這條路徑,請確保您的tree_unflatten
函數與__init__
保持同步,如果程式碼更新時。範例
def tree_unflatten(aux_data, children):
del aux_data # Unused in this class.
obj = object.__new__(MyTree)
obj.a = a
return obj
常見的 pytree 模式#
本節涵蓋 JAX pytree 中一些最常見的模式。
使用 jax.tree.map
和 jax.tree.transpose
轉置 pytree#
為了轉置 pytree(將樹狀結構列表轉換為樹狀結構的列表),JAX 有兩個函數:jax.tree.map()
(更基礎)和 jax.tree.transpose()
(更靈活、複雜且詳細)。
選項 1: 使用 jax.tree.map()
。這是一個範例
def tree_transpose(list_of_trees):
"""
Converts a list of trees of identical structure into a single tree of lists.
"""
return jax.tree.map(lambda *xs: list(xs), *list_of_trees)
# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
{'obs': [3, 4], 't': [1, 2]}
選項 2: 對於更複雜的轉置,使用 jax.tree.transpose()
,它更詳細,但允許您指定內部和外部 pytree 的結構以獲得更大的靈活性。例如
jax.tree.transpose(
outer_treedef = jax.tree.structure([0 for e in episode_steps]),
inner_treedef = jax.tree.structure(episode_steps[0]),
pytree_to_transpose = episode_steps
)
{'obs': [3, 4], 't': [1, 2]}