jax.stages 模組#

編譯執行過程各階段的介面。

即時編譯的 JAX 轉換,例如 jax.jitjax.pmap,也支援一種常見的預先顯式降低和編譯方式。此模組定義了表示此過程各階段的型別。

更多資訊,請參閱 AOT 導覽

類別#

class jax.stages.Wrapped(*args, **kwargs)[原始碼]#

準備好被追蹤、降低和編譯的函式。

此協定反映了 jax.jit 等函式的輸出。呼叫它會導致 JIT(即時)降低、編譯和執行。它也可以在編譯之前顯式降低,並且結果可以在執行之前編譯。

__call__(*args, **kwargs)[原始碼]#

執行包裝函式,根據需要進行降低和編譯。

lower(*args, **kwargs)[原始碼]#

針對給定的引數顯式降低此函式。

降低的函式會從 Python 中暫存出來,並轉換為編譯器的輸入語言,可能是以 backend 相依的方式。它已準備好進行編譯,但尚未編譯。

傳回:

表示降低的 Lowered 實例。

傳回型別:

Lowered

trace(*args, **kwargs)[原始碼]#

針對給定的引數顯式追蹤此函式。

追蹤的函式會從 Python 中暫存出來,並轉換為 jaxpr。它已準備好進行降低,但尚未降低。

傳回:

表示追蹤的 Traced 實例。

傳回型別:

Traced

class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)[原始碼]#

專門針對引數型別和值的函式降低。

降低是準備好進行編譯的計算。此類別攜帶降低,以及稍後編譯和執行它所需的其餘資訊。它還為查詢跨 JAX 各種降低路徑的降低計算屬性(jit()pmap() 等)提供常見的 API。

參數:
  • lowering (XlaLowering)

  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

  • no_kwargs (bool)

as_text(dialect=None, *, debug_info=False)[原始碼]#

此降低的可讀文字表示形式。

適用於視覺化和除錯目的。這不需要是有效或可靠的序列化。如果您想要可靠且可攜式的序列化,請使用 jax.export

參數:
  • dialect (str | None | None) – 可選字串,指定降低方言(例如 "stablehlo" 或 "hlo")。

  • debug_info (bool) – 是否包含除錯資訊,例如,原始碼位置。

傳回型別:

str

compile(compiler_options=None)[原始碼]#

編譯,傳回對應的 Compiled 實例。

參數:

compiler_options (CompilerOptions | None | None)

傳回型別:

Compiled

compiler_ir(dialect=None)[原始碼]#

此降低的任意物件表示形式。

適用於除錯目的。這不是有效或可靠的序列化。輸出無法保證跨調用的一致性。如果您想要可靠且可攜式的序列化,請使用 jax.export

如果不可用,則傳回 None,例如基於 backend、編譯器或執行階段。

參數:

dialect (str | None | None) – 可選字串,指定降低方言(例如 "stablehlo" 或 "hlo")。

傳回型別:

Any | None

cost_analysis()[原始碼]#

執行成本估計的摘要。

適用於視覺化和除錯目的。此物件輸出是一些簡單的資料結構,可以輕鬆列印或序列化(例如,具有數值葉節點的巢狀字典、列表和元組)。但是,其結構可以是任意的:它可能在 JAX 和 jaxlib 的版本之間,甚至在跨調用之間不一致。

如果不可用,則傳回 None,例如基於 backend、編譯器或執行階段。

傳回型別:

Any | None

property in_tree: tree_util.PyTreeDef[原始碼]#

配對的樹狀結構(位置引數、關鍵字引數)。

class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)[原始碼]#

專門針對型別/值的函式的編譯表示形式。

編譯的計算與可執行檔以及執行它所需的其餘資訊相關聯。它還為查詢跨 JAX 各種編譯路徑和 backend 的編譯計算屬性提供常見的 API。

參數:
  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

__call__(*args, **kwargs)[原始碼]#

將 self 作為函式呼叫。

as_text()[原始碼]#

此可執行檔的可讀文字表示形式。

適用於視覺化和除錯目的。這不是有效或可靠的序列化。

如果不可用,則傳回 None,例如基於 backend、編譯器或執行階段。

傳回型別:

str | None

cost_analysis()[原始碼]#

執行成本估計的摘要。

適用於視覺化和除錯目的。此物件輸出是一些簡單的資料結構,可以輕鬆列印或序列化(例如,具有數值葉節點的巢狀字典、列表和元組)。但是,其結構可以是任意的:它可能在 JAX 和 jaxlib 的版本之間,甚至在跨調用之間不一致。

如果不可用,則傳回 None,例如基於 backend、編譯器或執行階段。

傳回型別:

Any | None

property in_tree: tree_util.PyTreeDef[原始碼]#

配對的樹狀結構(位置引數、關鍵字引數)。

memory_analysis()[原始碼]#

記憶體需求估計的摘要。

適用於視覺化和除錯目的。此物件輸出是一些簡單的資料結構,可以輕鬆列印或序列化(例如,具有數值葉節點的巢狀字典、列表和元組)。但是,其結構可以是任意的:它可能在 JAX 和 jaxlib 的版本之間,甚至在跨調用之間不一致。

如果不可用,則傳回 None,例如基於 backend、編譯器或執行階段。

傳回型別:

Any | None

runtime_executable()[原始碼]#

此可執行檔的任意物件表示形式。

適用於除錯目的。這不是有效或可靠的序列化。輸出無法保證跨調用的一致性。

如果不可用,則傳回 None,例如基於 backend、編譯器或執行階段。

傳回型別:

Any | None