jax.lax.platform_dependent#

jax.lax.platform_dependent(*args, default=None, **per_platform)[source]#

暫存平台特定程式碼。

在 JAX 中,實際執行運算的平台非常晚才確定,例如,根據資料所在位置而定。當使用 AOT 降低或序列化時,運算可能會在不同的機器上編譯和執行,甚至在降低時無法使用的平台上執行。這表示使用 Python 條件式編寫平台相關程式碼並不安全,例如,根據目前的預設 JAX 平台。相反地,可以使用 platform_dependent

用法

def cpu_code(*args): ...
def tpu_code(*args): ...
def other_platforms_code(*args): ...
res = platform_dependent(*args, cpu=cpu_code, tpu=tpu_code,
                         default=other_platforms_code)

當暫存程式碼在 CPU 上執行時,這等同於 cpu_code(*args);在 TPU 上執行時,等同於 tpu_code(*args);在任何其他平台上執行時,等同於 other_platforms_code(*args)。與 Python 條件式不同,所有替代方案都會被追蹤並暫存到 Jaxpr。這類似於 switch(),並根據它實作,從中繼承轉換下的行為。

switch() 不同,執行內容的選擇時間點較早:在大多數情況下,在降低平台已知時進行降低期間;在罕見的多平台降低和序列化情況下,StableHLO 程式碼將包含實際平台的條件式。此條件式會在編譯平台已知時,在編譯之前及時解析。這表示編譯器實際上永遠不會看到條件式。

參數:
  • *args (Any) – 傳遞至每個分支的 JAX 陣列。可能是 PyTrees。

  • **per_platform (Callable[..., _T]) – 用於不同平台的分支。這些分支是使用 *args 叫用的 JAX 可呼叫物件。關鍵字是平台名稱,例如 ‘cpu’、‘tpu’、‘cuda’、‘rocm’。

  • default (Callable[..., _T] | None | None) – 用於未在 per_platform 中提及的平台的選用預設分支。如果沒有 default,當程式碼針對未在 per_platform 中提及的平台降低時,將會發生錯誤。

傳回:

per_platform[execution_platform](*args)