jax.tree_util.register_dataclass#

jax.tree_util.register_dataclass(nodetype, data_fields=None, meta_fields=None, drop_fields=())[原始碼]#

擴展了在 pytrees 中被視為內部節點的類型集合。

這與 register_pytree_with_keys_class 的不同之處在於,C++ 註冊表使用最佳化的 C++ dataclass 內建函式,而不是引數函式。

請參閱 擴展 pytrees 以取得更多關於註冊 pytrees 的資訊。

參數:
  • nodetype (Typ) – 一個 Python 類型,將其視為內部 pytree 節點。這假定具有 dataclass 的語意:也就是說,類別屬性代表物件狀態的整體,並且可以作為關鍵字傳遞給類別建構子以建立物件的副本。所有已定義的屬性都應列在 meta_fieldsdata_fields 中。

  • meta_fields (Sequence[str] | None | None) – 元資料欄位名稱:這些是在 pytree 傳遞至 jax.jit() 時將被視為 {term}`static` 的屬性。meta_fields 僅在 nodetype 是 dataclass 時才是選填的,在這種情況下,可以使用 dataclasses.field() 將個別欄位標記為靜態 (請參閱以下範例)。元資料欄位必須是靜態、可雜湊、不可變的物件,因為這些物件用於產生 JIT 快取金鑰。特別是,元資料欄位不能包含 jax.Arraynumpy.ndarray 物件。

  • data_fields (Sequence[str] | None | None) – 資料欄位名稱:這些是在 pytree 傳遞至 jax.jit() 時將被視為非靜態的屬性。data_fields 僅在 nodetype 是 dataclass 時才是選填的,在這種情況下,除非使用 dataclasses.field() 標記,否則欄位會被假定為資料欄位 (請參閱以下範例)。資料欄位必須是與 JAX 相容的物件,例如陣列 (jax.Arraynumpy.ndarray)、純量或 leaves 是陣列或純量的 pytrees。請注意,None 是有效的資料欄位,因為 JAX 將其識別為空的 pytree。

  • drop_fields (Sequence[str])

傳回值:

輸入類別 nodetype 在新增至 JAX 的 pytree 註冊表後,會保持不變傳回,因此 register_dataclass() 可以作為裝飾器使用。

傳回類型:

Typ

範例

在 JAX v0.4.35 或更舊版本中,您必須指定 data_fieldsmeta_fields 才能使用此裝飾器

>>> import jax
>>> from dataclasses import dataclass
>>> from functools import partial
...
>>> @partial(jax.tree_util.register_dataclass,
...          data_fields=['x', 'y'],
...          meta_fields=['op'])
... @dataclass
... class MyStruct:
...   x: jax.Array
...   y: jax.Array
...   op: str
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

從 JAX v0.4.36 開始,對於 dataclass() 輸入,data_fieldsmeta_fields 引數是選填的,欄位預設為 data_fields,除非使用 static 元資料在 dataclasses.field() 中標記為靜態。

>>> import jax
>>> from dataclasses import dataclass, field
...
>>> @jax.tree_util.register_dataclass
... @dataclass
... class MyStruct:
...   x: jax.Array  # defaults to non-static data field
...   y: jax.Array  # defaults to non-static data field
...   op: str = field(metadata=dict(static=True))  # marked as static meta field.
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

一旦註冊此類別,即可將其與 jax.treejax.tree_util 中的函式搭配使用

>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

特別是,此註冊允許 m 無縫地傳遞通過以 jax.jit() 和其他 JAX 轉換包裝的程式碼,其中 data_fields 被視為動態引數,而 meta_fields 被視為靜態引數

>>> @jax.jit
... def compiled_func(m):
...   if m.op == 'add':
...     return m.x + m.y
...   else:
...     raise ValueError(f"{m.op=}")
...
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)