jax.export.register_pytree_node_serialization#
- jax.export.register_pytree_node_serialization(nodetype, *, serialized_name, serialize_auxdata, deserialize_auxdata, from_children=None)[原始碼]#
為序列化和反序列化註冊自訂 PyTree 節點。
您必須先使用此函式,才能序列化和反序列化本機不支援類型的 PyTree 節點。我們會序列化 Exported 的 in_tree 和 out_tree 欄位的 PyTree 節點,這些是匯出函式呼叫慣例的一部分。
此函式必須在呼叫 jax.tree_util.register_pytree_node 之後呼叫 (除了 collections.namedtuple,其不需要呼叫 register_pytree_node)。
- 參數:
nodetype (type[T]) – 我們要序列化其 PyTree 節點的類型。嘗試為 nodetype 註冊多個序列化是錯誤的。
serialized_name (str) – 一個字串,將出現在序列化中,並將用於在反序列化期間查找註冊。嘗試為 serialized_name 註冊多個序列化是錯誤的。
serialize_auxdata (_SerializeAuxData) – 序列化 PyTree 輔助資料 (由 jax.tree_util.register_pytree_node 的 flatten_func 引數傳回)。
deserialize_auxdata (_DeserializeAuxData) – 反序列化由 serialize_auxdata 序列化的輔助資料。
from_children (_BuildFromChildren | None | None) – 如果存在,這是一個函式,它接受 deserialize_auxdata 的結果以及一些子節點,並建立 nodetype 的實例。這類似於傳遞給 jax.tree_util.register_pytree_node 的 unflatten_func。如果不存在,我們會查找並使用 unflatten_func。這是 collections.namedtuple 所需的,它沒有 register_pytree_node,但覆寫該函式可能很有用。請注意,from_children 的結果僅與 jax.tree_util.tree_structure 一起使用以建構正確的 PyTree 節點,它不被用於建構序列化函式的輸出。
- 傳回:
作為 nodetype 傳遞的相同類型,以便此函式可以用作類別裝飾器。
- 傳回類型:
type[T]