jax.tree_util.treedef_tuple#
- jax.tree_util.treedef_tuple(treedefs)[原始碼]#
從子 treedef 的可迭代物件建立 tuple treedef。
- 參數:
treedefs (Iterable[PyTreeDef]) – PyTree 結構的可迭代物件
- 返回:
代表結構 tuple 的單個 treedef
- 返回類型:
PyTreeDef
範例
>>> import jax >>> x = [1, 2, 3] >>> y = {'a': 4, 'b': 5} >>> x_tree = jax.tree.structure(x) >>> y_tree = jax.tree.structure(y) >>> xy_tree = jax.tree_util.treedef_tuple([x_tree, y_tree]) >>> xy_tree == jax.tree.structure((x, y)) True