jax.export.symbolic_args_specs#

jax.export.symbolic_args_specs(args, shapes_specs, constraints=(), scope=None)[原始碼]#

export 建構 jax.ShapeDtypeSpec 參數規格的 pytree。

請參閱 jax.export.symbolic_shape() 的文件和[形狀多型性文件](https://jax.dev.org.tw/en/latest/export/shape_poly.html) 以取得詳細資訊。

參數:
傳回值:符合 args 的 jax.ShapeDTypeStruct pytree,其形狀

shapes_specs 指定的符號維度取代。