jax.stages
模組#
編譯執行過程各階段的介面。
即時編譯的 JAX 轉換,例如 jax.jit
和 jax.pmap
,也支援一種常見的預先顯式降低和編譯方式。此模組定義了表示此過程各階段的型別。
更多資訊,請參閱 AOT 導覽。
類別#
- class jax.stages.Wrapped(*args, **kwargs)[原始碼]#
準備好被追蹤、降低和編譯的函式。
此協定反映了
jax.jit
等函式的輸出。呼叫它會導致 JIT(即時)降低、編譯和執行。它也可以在編譯之前顯式降低,並且結果可以在執行之前編譯。
- class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)[原始碼]#
專門針對引數型別和值的函式降低。
降低是準備好進行編譯的計算。此類別攜帶降低,以及稍後編譯和執行它所需的其餘資訊。它還為查詢跨 JAX 各種降低路徑的降低計算屬性(
jit()
、pmap()
等)提供常見的 API。- 參數:
lowering (XlaLowering)
args_info (Any)
out_tree (tree_util.PyTreeDef)
no_kwargs (bool)
- as_text(dialect=None, *, debug_info=False)[原始碼]#
此降低的可讀文字表示形式。
適用於視覺化和除錯目的。這不需要是有效或可靠的序列化。如果您想要可靠且可攜式的序列化,請使用 jax.export。
- compile(compiler_options=None)[原始碼]#
編譯,傳回對應的
Compiled
實例。- 參數:
compiler_options (CompilerOptions | None | None)
- 傳回型別:
- compiler_ir(dialect=None)[原始碼]#
此降低的任意物件表示形式。
適用於除錯目的。這不是有效或可靠的序列化。輸出無法保證跨調用的一致性。如果您想要可靠且可攜式的序列化,請使用 jax.export。
如果不可用,則傳回
None
,例如基於 backend、編譯器或執行階段。- 參數:
dialect (str | None | None) – 可選字串,指定降低方言(例如 "stablehlo" 或 "hlo")。
- 傳回型別:
Any | None
- class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)[原始碼]#
專門針對型別/值的函式的編譯表示形式。
編譯的計算與可執行檔以及執行它所需的其餘資訊相關聯。它還為查詢跨 JAX 各種編譯路徑和 backend 的編譯計算屬性提供常見的 API。
- 參數:
args_info (Any)
out_tree (tree_util.PyTreeDef)
- as_text()[原始碼]#
此可執行檔的可讀文字表示形式。
適用於視覺化和除錯目的。這不是有效或可靠的序列化。
如果不可用,則傳回
None
,例如基於 backend、編譯器或執行階段。- 傳回型別:
str | None
- cost_analysis()[原始碼]#
執行成本估計的摘要。
適用於視覺化和除錯目的。此物件輸出是一些簡單的資料結構,可以輕鬆列印或序列化(例如,具有數值葉節點的巢狀字典、列表和元組)。但是,其結構可以是任意的:它可能在 JAX 和 jaxlib 的版本之間,甚至在跨調用之間不一致。
如果不可用,則傳回
None
,例如基於 backend、編譯器或執行階段。- 傳回型別:
Any | None