jax.lax.platform_dependent#
- jax.lax.platform_dependent(*args, default=None, **per_platform)[source]#
暫存平台特定程式碼。
在 JAX 中,實際執行運算的平台非常晚才確定,例如,根據資料所在位置而定。當使用 AOT 降低或序列化時,運算可能會在不同的機器上編譯和執行,甚至在降低時無法使用的平台上執行。這表示使用 Python 條件式編寫平台相關程式碼並不安全,例如,根據目前的預設 JAX 平台。相反地,可以使用
platform_dependent
用法
def cpu_code(*args): ... def tpu_code(*args): ... def other_platforms_code(*args): ... res = platform_dependent(*args, cpu=cpu_code, tpu=tpu_code, default=other_platforms_code)
當暫存程式碼在 CPU 上執行時,這等同於
cpu_code(*args)
;在 TPU 上執行時,等同於tpu_code(*args)
;在任何其他平台上執行時,等同於other_platforms_code(*args)
。與 Python 條件式不同,所有替代方案都會被追蹤並暫存到 Jaxpr。這類似於switch()
,並根據它實作,從中繼承轉換下的行為。與
switch()
不同,執行內容的選擇時間點較早:在大多數情況下,在降低平台已知時進行降低期間;在罕見的多平台降低和序列化情況下,StableHLO 程式碼將包含實際平台的條件式。此條件式會在編譯平台已知時,在編譯之前及時解析。這表示編譯器實際上永遠不會看到條件式。- 參數:
*args (Any) – 傳遞至每個分支的 JAX 陣列。可能是 PyTrees。
**per_platform (Callable[..., _T]) – 用於不同平台的分支。這些分支是使用
*args
叫用的 JAX 可呼叫物件。關鍵字是平台名稱,例如 ‘cpu’、‘tpu’、‘cuda’、‘rocm’。default (Callable[..., _T] | None | None) – 用於未在
per_platform
中提及的平台的選用預設分支。如果沒有default
,當程式碼針對未在per_platform
中提及的平台降低時,將會發生錯誤。
- 傳回:
值
per_platform[execution_platform](*args)
。