jax.export.symbolic_args_specs#
- jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None)[原始碼]#
為 export 建構 jax.ShapeDtypeSpec 參數規格的 pytree。
請參閱
jax.export.symbolic_shape()
的文件和[形狀多型性文件](https://jax.dev.org.tw/en/latest/export/shape_poly.html) 以取得詳細資訊。- 參數:
args – 參數的 pytree。這些可以是 jax.Array 或 jax.ShapeDTypeSpec。它們用於學習參數的 pytree 結構、其 dtype,並在 shapes_specs 包含預留位置時填入實際形狀。請注意,僅從 args 中使用 shapes_specs 作為預留位置的形狀維度。
shapes_specs – 應為 None (所有參數都具有靜態形狀)、單一字串 (請參閱
jax.export.symbolic_shape()
的 shape_spec;適用於所有參數),或符合 args 字首的 pytree。請參閱[如何將選用參數與 pytree 匹配](https://jax.dev.org.tw/en/latest/pytrees.html#applying-optional-parameters-to-pytrees)。constraints (Sequence[str]) – 與
jax.export.symbolic_shape()
相同。scope (SymbolicScope | None | None) – 與
jax.export.symbolic_shape()
相同。
- 傳回值:符合 args 的 jax.ShapeDTypeStruct pytree,其形狀
由 shapes_specs 指定的符號維度取代。