jax.tree_util.register_dataclass#
- jax.tree_util.register_dataclass(nodetype, data_fields=None, meta_fields=None, drop_fields=())[原始碼]#
擴展了在 pytrees 中被視為內部節點的類型集合。
這與
register_pytree_with_keys_class
的不同之處在於,C++ 註冊表使用最佳化的 C++ dataclass 內建函式,而不是引數函式。請參閱 擴展 pytrees 以取得更多關於註冊 pytrees 的資訊。
- 參數:
nodetype (Typ) – 一個 Python 類型,將其視為內部 pytree 節點。這假定具有
dataclass
的語意:也就是說,類別屬性代表物件狀態的整體,並且可以作為關鍵字傳遞給類別建構子以建立物件的副本。所有已定義的屬性都應列在meta_fields
或data_fields
中。meta_fields (Sequence[str] | None | None) – 元資料欄位名稱:這些是在 pytree 傳遞至
jax.jit()
時將被視為 {term}`static` 的屬性。meta_fields
僅在nodetype
是 dataclass 時才是選填的,在這種情況下,可以使用dataclasses.field()
將個別欄位標記為靜態 (請參閱以下範例)。元資料欄位必須是靜態、可雜湊、不可變的物件,因為這些物件用於產生 JIT 快取金鑰。特別是,元資料欄位不能包含jax.Array
或numpy.ndarray
物件。data_fields (Sequence[str] | None | None) – 資料欄位名稱:這些是在 pytree 傳遞至
jax.jit()
時將被視為非靜態的屬性。data_fields
僅在nodetype
是 dataclass 時才是選填的,在這種情況下,除非使用dataclasses.field()
標記,否則欄位會被假定為資料欄位 (請參閱以下範例)。資料欄位必須是與 JAX 相容的物件,例如陣列 (jax.Array
或numpy.ndarray
)、純量或 leaves 是陣列或純量的 pytrees。請注意,None
是有效的資料欄位,因為 JAX 將其識別為空的 pytree。drop_fields (Sequence[str])
- 傳回值:
輸入類別
nodetype
在新增至 JAX 的 pytree 註冊表後,會保持不變傳回,因此register_dataclass()
可以作為裝飾器使用。- 傳回類型:
Typ
範例
在 JAX v0.4.35 或更舊版本中,您必須指定
data_fields
和meta_fields
才能使用此裝飾器>>> import jax >>> from dataclasses import dataclass >>> from functools import partial ... >>> @partial(jax.tree_util.register_dataclass, ... data_fields=['x', 'y'], ... meta_fields=['op']) ... @dataclass ... class MyStruct: ... x: jax.Array ... y: jax.Array ... op: str ... >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
從 JAX v0.4.36 開始,對於
dataclass()
輸入,data_fields
和meta_fields
引數是選填的,欄位預設為data_fields
,除非使用 static 元資料在dataclasses.field()
中標記為靜態。>>> import jax >>> from dataclasses import dataclass, field ... >>> @jax.tree_util.register_dataclass ... @dataclass ... class MyStruct: ... x: jax.Array # defaults to non-static data field ... y: jax.Array # defaults to non-static data field ... op: str = field(metadata=dict(static=True)) # marked as static meta field. ... >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add') >>> m MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
一旦註冊此類別,即可將其與
jax.tree
和jax.tree_util
中的函式搭配使用>>> leaves, treedef = jax.tree.flatten(m) >>> leaves [Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)] >>> treedef PyTreeDef(CustomNode(MyStruct[('add',)], [*, *])) >>> jax.tree.unflatten(treedef, leaves) MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
特別是,此註冊允許
m
無縫地傳遞通過以jax.jit()
和其他 JAX 轉換包裝的程式碼,其中data_fields
被視為動態引數,而meta_fields
被視為靜態引數>>> @jax.jit ... def compiled_func(m): ... if m.op == 'add': ... return m.x + m.y ... else: ... raise ValueError(f"{m.op=}") ... >>> compiled_func(m) Array([1., 2., 3.], dtype=float32)