jax.experimental.pjit 模組#

API#

jax.experimental.pjit.pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None, compiler_options=None)[原始碼]#

使 fun 經過編譯,並自動跨多個裝置進行分割。

注意:此函式現在等同於 jax.jit,請改用該函式。傳回的函式具有與 fun 等效的語義,但會編譯成跨多個裝置(例如多個 GPU 或多個 TPU 核心)執行的 XLA 計算。如果 fun 的 jitted 版本無法容納在單一裝置的記憶體中,或者為了透過跨多個裝置平行執行每個操作來加速 fun,則此功能非常有用。

裝置上的分割會根據 in_shardings 中指定的輸入分割和 out_shardings 中指定的輸出分割自動進行。這兩個引數中指定的資源必須參照網格軸,如 jax.sharding.Mesh() 上下文管理器所定義。請注意,在 pjit() 應用時的網格定義會被忽略,而傳回的函式將使用每個呼叫站點可用的網格定義。

如果 pjit() 函式的輸入尚未根據 in_shardings 正確分割,則會自動跨裝置分割。在某些情況下,確保輸入已正確預先分割可以提高效能。例如,如果將一個 pjit() 函式的輸出傳遞給另一個 pjit() 函式(或迴圈中的同一個 pjit() 函式),請確保相關的 out_shardings 符合對應的 in_shardings

注意

多進程平台: 在多進程平台(例如 TPU pods)上,pjit() 可用於跨進程在所有可用裝置上執行計算。為了實現這一點,pjit() 設計用於 SPMD Python 程式中,其中每個進程都執行相同的 Python 程式碼,以便所有進程都以相同的順序執行相同的 pjit() 函式。

在此組態中執行時,網格應包含跨所有進程的裝置。所有輸入引數都必須是全域形狀。fun 仍將在網格中的所有裝置上執行,包括來自其他進程的裝置,並且將被賦予跨多個進程散佈的資料的全域視圖,作為單一陣列。

SPMD 模型也要求相同的多進程 pjit() 函式必須在所有進程上以相同的順序執行,但它們可以與在單一進程中執行的任意操作穿插。

參數:
  • fun (Callable) – 要編譯的函式。應為純函式,因為副作用可能只執行一次。其引數和傳回值應為陣列、純量或(巢狀)標準 Python 容器(tuple/list/dict)。由 static_argnums 指示的位置引數可以是任何東西,前提是它們是可雜湊的並且定義了相等運算。靜態引數包含在編譯快取金鑰的一部分中,這就是為什麼必須定義雜湊和相等運算子的原因。

  • in_shardings

    結構與 fun 引數結構相符的 Pytree,所有實際引數都替換為資源分配規範。指定 pytree 字首(例如,一個值代替整個子樹)也是有效的,在這種情況下,葉節點會廣播到該子樹中的所有值。

    in_shardings 引數是可選的。JAX 將從輸入 jax.Array 推斷分片,如果無法推斷分片,則預設為複製輸入。

    有效的資源分配規範為

    • Sharding,它將決定如何分割值。使用此項,不需要使用網格上下文管理器。

    • None 是一種特殊情況,其語義為
      • 如果提供網格上下文管理器,則 JAX 可以自由選擇它想要的任何分片。對於 in_shardings,JAX 會將其標記為已複製,但此行為將來可能會變更。對於 out_shardings,我們將依賴 XLA GSPMD 分割器來確定輸出分片。

      • 如果提供了網格上下文管理器,則 None 將表示值將在網格的所有裝置上複製。

    • 為了向後相容性,in_shardings 仍然支援攝取 PartitionSpec。此選項只能與網格上下文管理器一起使用。

      • PartitionSpec,一個元組,其長度最多等於分割值的秩。每個元素可以是 None、網格軸或網格軸元組,並指定分配給分割值維度的資源集,該維度與其在規範中的位置相符。

    每個維度的大小必須是分配給它的資源總數的倍數。

  • out_shardings – 類似於 in_shardings,但指定函式輸出的資源分配。out_shardings 引數是可選的。如果未指定,jax.jit() 將使用 GSPMD 的分片傳播來確定如何對輸出進行分片。

  • static_argnums (int | Sequence[int] | None | None) –

    一個可選的整數或整數集合,用於指定將哪些位置引數視為靜態(編譯時常數)。僅依賴靜態引數的操作將在 Python 中(在追蹤期間)進行常數折疊,因此對應的引數值可以是任何 Python 物件。

    靜態引數應為可雜湊的,表示已實作 __hash____eq__,且為不可變的。使用這些常數的不同值呼叫 jitted 函式將觸發重新編譯。非陣列或其容器的引數必須標記為靜態。

    如果未提供 static_argnums,則不會將任何引數視為靜態。

  • static_argnames (str | Iterable[str] | None | None) – 一個可選的字串或字串集合,用於指定將哪些具名引數視為靜態(編譯時常數)。有關詳細資訊,請參閱 static_argnums 的註解。如果未提供但設定了 static_argnums,則預設值基於呼叫 inspect.signature(fun) 來尋找對應的具名引數。

  • donate_argnums (int | Sequence[int] | None | None) –

    指定哪些位置引數緩衝區「捐贈」給計算。如果您在計算完成後不再需要引數緩衝區,則捐贈引數緩衝區是安全的。在某些情況下,XLA 可以利用捐贈的緩衝區來減少執行計算所需的記憶體量,例如回收您的其中一個輸入緩衝區來儲存結果。您不應重複使用您捐贈給計算的緩衝區,如果您嘗試這樣做,JAX 將引發錯誤。預設情況下,不捐贈任何引數緩衝區。

    如果既未提供 donate_argnums 也未提供 donate_argnames,則不捐贈任何引數。如果未提供 donate_argnums 但提供了 donate_argnames,反之亦然,JAX 會使用 inspect.signature(fun) 來尋找對應於 donate_argnames 的任何位置引數(或反之亦然)。如果同時提供了 donate_argnumsdonate_argnames,則不會使用 inspect.signature,並且只會捐贈 donate_argnumsdonate_argnames 中列出的實際參數。

    有關緩衝區捐贈的更多詳細資訊,請參閱 FAQ

  • donate_argnames (str | Iterable[str] | None | None) – 一個可選的字串或字串集合,用於指定將哪些具名引數捐贈給計算。有關詳細資訊,請參閱 donate_argnums 的註解。如果未提供但設定了 donate_argnums,則預設值基於呼叫 inspect.signature(fun) 來尋找對應的具名引數。

  • keep_unused (bool) – 如果 False(預設值),JAX 確定 fun 未使用的引數可能會從產生的已編譯 XLA 可執行檔中刪除。此類引數將不會傳輸到裝置,也不會提供給底層可執行檔。如果 True,則不會修剪未使用的引數。

  • device (xc.Device | None | None) – 此引數已棄用。請在將引數傳遞給 jit 之前將其放置在您想要的裝置上。可選,jitted 函式將在其上執行的裝置。(可用裝置可以透過 jax.devices() 檢索。)預設值繼承自 XLA 的 DeviceAssignment 邏輯,通常是使用 jax.devices()[0]

  • backend (str | None | None) – 此引數已棄用。請在將引數傳遞給 jit 之前將其放置在您想要的後端上。可選,表示 XLA 後端的字串:'cpu''gpu''tpu'

  • inline (bool)

  • abstracted_axes (Any | None | None)

  • compiler_options (dict[str, Any] | None | None)

傳回:

fun 的包裝版本,設定為即時編譯,並由每個呼叫站點可用的網格自動分割。

傳回類型:

JitWrapped

例如,卷積運算子可以透過單一 pjit() 應用程式自動在任意裝置集上分割

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.sharding import Mesh, PartitionSpec
>>> from jax.experimental.pjit import pjit
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'),
...         in_shardings=None, out_shardings=PartitionSpec('devices'))
>>> with Mesh(np.array(jax.devices()), ('devices',)):
...   print(f(x))  
[ 0.5  2.   4.   6.   8.  10.  12.  10. ]