jax.experimental.pjit
模組#
API#
- jax.experimental.pjit.pjit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None, compiler_options=None)[原始碼]#
使
fun
經過編譯,並自動跨多個裝置進行分割。注意:此函式現在等同於 jax.jit,請改用該函式。傳回的函式具有與
fun
等效的語義,但會編譯成跨多個裝置(例如多個 GPU 或多個 TPU 核心)執行的 XLA 計算。如果fun
的 jitted 版本無法容納在單一裝置的記憶體中,或者為了透過跨多個裝置平行執行每個操作來加速fun
,則此功能非常有用。裝置上的分割會根據
in_shardings
中指定的輸入分割和out_shardings
中指定的輸出分割自動進行。這兩個引數中指定的資源必須參照網格軸,如jax.sharding.Mesh()
上下文管理器所定義。請注意,在pjit()
應用時的網格定義會被忽略,而傳回的函式將使用每個呼叫站點可用的網格定義。如果
pjit()
函式的輸入尚未根據in_shardings
正確分割,則會自動跨裝置分割。在某些情況下,確保輸入已正確預先分割可以提高效能。例如,如果將一個pjit()
函式的輸出傳遞給另一個pjit()
函式(或迴圈中的同一個pjit()
函式),請確保相關的out_shardings
符合對應的in_shardings
。注意
多進程平台: 在多進程平台(例如 TPU pods)上,
pjit()
可用於跨進程在所有可用裝置上執行計算。為了實現這一點,pjit()
設計用於 SPMD Python 程式中,其中每個進程都執行相同的 Python 程式碼,以便所有進程都以相同的順序執行相同的pjit()
函式。在此組態中執行時,網格應包含跨所有進程的裝置。所有輸入引數都必須是全域形狀。
fun
仍將在網格中的所有裝置上執行,包括來自其他進程的裝置,並且將被賦予跨多個進程散佈的資料的全域視圖,作為單一陣列。SPMD 模型也要求相同的多進程
pjit()
函式必須在所有進程上以相同的順序執行,但它們可以與在單一進程中執行的任意操作穿插。- 參數:
fun (Callable) – 要編譯的函式。應為純函式,因為副作用可能只執行一次。其引數和傳回值應為陣列、純量或(巢狀)標準 Python 容器(tuple/list/dict)。由
static_argnums
指示的位置引數可以是任何東西,前提是它們是可雜湊的並且定義了相等運算。靜態引數包含在編譯快取金鑰的一部分中,這就是為什麼必須定義雜湊和相等運算子的原因。in_shardings –
結構與
fun
引數結構相符的 Pytree,所有實際引數都替換為資源分配規範。指定 pytree 字首(例如,一個值代替整個子樹)也是有效的,在這種情況下,葉節點會廣播到該子樹中的所有值。in_shardings
引數是可選的。JAX 將從輸入jax.Array
推斷分片,如果無法推斷分片,則預設為複製輸入。有效的資源分配規範為
Sharding
,它將決定如何分割值。使用此項,不需要使用網格上下文管理器。None
是一種特殊情況,其語義為如果未提供網格上下文管理器,則 JAX 可以自由選擇它想要的任何分片。對於 in_shardings,JAX 會將其標記為已複製,但此行為將來可能會變更。對於 out_shardings,我們將依賴 XLA GSPMD 分割器來確定輸出分片。
如果提供了網格上下文管理器,則 None 將表示值將在網格的所有裝置上複製。
為了向後相容性,in_shardings 仍然支援攝取
PartitionSpec
。此選項只能與網格上下文管理器一起使用。PartitionSpec
,一個元組,其長度最多等於分割值的秩。每個元素可以是None
、網格軸或網格軸元組,並指定分配給分割值維度的資源集,該維度與其在規範中的位置相符。
每個維度的大小必須是分配給它的資源總數的倍數。
out_shardings – 類似於
in_shardings
,但指定函式輸出的資源分配。out_shardings
引數是可選的。如果未指定,jax.jit()
將使用 GSPMD 的分片傳播來確定如何對輸出進行分片。static_argnums (int | Sequence[int] | None | None) –
一個可選的整數或整數集合,用於指定將哪些位置引數視為靜態(編譯時常數)。僅依賴靜態引數的操作將在 Python 中(在追蹤期間)進行常數折疊,因此對應的引數值可以是任何 Python 物件。
靜態引數應為可雜湊的,表示已實作
__hash__
和__eq__
,且為不可變的。使用這些常數的不同值呼叫 jitted 函式將觸發重新編譯。非陣列或其容器的引數必須標記為靜態。如果未提供
static_argnums
,則不會將任何引數視為靜態。static_argnames (str | Iterable[str] | None | None) – 一個可選的字串或字串集合,用於指定將哪些具名引數視為靜態(編譯時常數)。有關詳細資訊,請參閱
static_argnums
的註解。如果未提供但設定了static_argnums
,則預設值基於呼叫inspect.signature(fun)
來尋找對應的具名引數。donate_argnums (int | Sequence[int] | None | None) –
指定哪些位置引數緩衝區「捐贈」給計算。如果您在計算完成後不再需要引數緩衝區,則捐贈引數緩衝區是安全的。在某些情況下,XLA 可以利用捐贈的緩衝區來減少執行計算所需的記憶體量,例如回收您的其中一個輸入緩衝區來儲存結果。您不應重複使用您捐贈給計算的緩衝區,如果您嘗試這樣做,JAX 將引發錯誤。預設情況下,不捐贈任何引數緩衝區。
如果既未提供
donate_argnums
也未提供donate_argnames
,則不捐贈任何引數。如果未提供donate_argnums
但提供了donate_argnames
,反之亦然,JAX 會使用inspect.signature(fun)
來尋找對應於donate_argnames
的任何位置引數(或反之亦然)。如果同時提供了donate_argnums
和donate_argnames
,則不會使用inspect.signature
,並且只會捐贈donate_argnums
或donate_argnames
中列出的實際參數。有關緩衝區捐贈的更多詳細資訊,請參閱 FAQ。
donate_argnames (str | Iterable[str] | None | None) – 一個可選的字串或字串集合,用於指定將哪些具名引數捐贈給計算。有關詳細資訊,請參閱
donate_argnums
的註解。如果未提供但設定了donate_argnums
,則預設值基於呼叫inspect.signature(fun)
來尋找對應的具名引數。keep_unused (bool) – 如果 False(預設值),JAX 確定 fun 未使用的引數可能會從產生的已編譯 XLA 可執行檔中刪除。此類引數將不會傳輸到裝置,也不會提供給底層可執行檔。如果 True,則不會修剪未使用的引數。
device (xc.Device | None | None) – 此引數已棄用。請在將引數傳遞給 jit 之前將其放置在您想要的裝置上。可選,jitted 函式將在其上執行的裝置。(可用裝置可以透過
jax.devices()
檢索。)預設值繼承自 XLA 的 DeviceAssignment 邏輯,通常是使用jax.devices()[0]
。backend (str | None | None) – 此引數已棄用。請在將引數傳遞給 jit 之前將其放置在您想要的後端上。可選,表示 XLA 後端的字串:
'cpu'
、'gpu'
或'tpu'
。inline (bool)
abstracted_axes (Any | None | None)
compiler_options (dict[str, Any] | None | None)
- 傳回:
fun
的包裝版本,設定為即時編譯,並由每個呼叫站點可用的網格自動分割。- 傳回類型:
JitWrapped
例如,卷積運算子可以透過單一
pjit()
應用程式自動在任意裝置集上分割>>> import jax >>> import jax.numpy as jnp >>> import numpy as np >>> from jax.sharding import Mesh, PartitionSpec >>> from jax.experimental.pjit import pjit >>> >>> x = jnp.arange(8, dtype=jnp.float32) >>> f = pjit(lambda x: jax.numpy.convolve(x, jnp.asarray([0.5, 1.0, 0.5]), 'same'), ... in_shardings=None, out_shardings=PartitionSpec('devices')) >>> with Mesh(np.array(jax.devices()), ('devices',)): ... print(f(x)) [ 0.5 2. 4. 6. 8. 10. 12. 10. ]