jax.jit#

jax.jit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None, compiler_options=None)[原始碼]#

設定 fun 以使用 XLA 進行即時編譯。

參數:
  • fun (Callable) –

    要進行 JIT 編譯的函式。 fun 應為純函式。

    fun 的引數和傳回值應為陣列、純量或它們的(巢狀)標準 Python 容器(tuple/list/dict)。由 static_argnums 指示的位置引數可以是任何可雜湊的型別。靜態引數會包含在編譯快取金鑰中,這就是為什麼必須定義雜湊和相等運算子的原因。JAX 會保留對 fun 的弱參考,以用作編譯快取金鑰,因此物件 fun 必須是弱參考的。

  • in_shardings – 選填,一個 Sharding 或 pytree,其具有 Sharding 葉節點和結構,該結構是傳遞給 fun 的位置引數 tuple 的樹狀前綴。如果提供,則傳遞給 fun 的位置引數必須具有與 in_shardings 相容的分片,否則會引發錯誤,並且編譯後的計算具有對應於 in_shardings 的輸入分片。如果未提供,則編譯後的計算的輸入分片會從引數分片推斷。

  • out_shardings – 選填,一個 Sharding 或 pytree,其具有 Sharding 葉節點和結構,該結構是 fun 輸出的樹狀前綴。如果提供,則其效果與將對應的 jax.lax.with_sharding_constraint 應用於 fun() 的輸出相同。

  • static_argnums (int | Sequence[int] | None | None) –

    選填,一個整數或整數集合,指定要將哪些位置引數視為靜態(追蹤和編譯時常數)。

    靜態引數應為可雜湊的,表示 __hash____eq__ 均已實作,且為不可變的。否則,它們可以是任意 Python 物件。使用這些常數的不同值呼叫 JIT 編譯函式將觸發重新編譯。非陣列式或其容器的引數必須標記為靜態。

    如果未提供 static_argnums 也未提供 static_argnames,則不會將任何引數視為靜態。如果未提供 static_argnums 但提供了 static_argnames,反之亦然,JAX 會使用 inspect.signature(fun) 來尋找對應於 static_argnames 的任何位置引數(反之亦然)。如果同時提供了 static_argnumsstatic_argnames,則不會使用 inspect.signature,並且僅將 static_argnumsstatic_argnames 中列出的實際參數視為靜態。

  • static_argnames (str | Iterable[str] | None | None) – 選填,一個字串或字串集合,指定要將哪些具名引數視為靜態(編譯時常數)。請參閱關於 static_argnums 的註解以取得詳細資訊。如果未提供但已設定 static_argnums,則預設值基於呼叫 inspect.signature(fun) 以尋找對應的具名引數。

  • donate_argnums (int | Sequence[int] | None | None) –

    選填,整數集合,用於指定哪些位置引數緩衝區可以被計算覆寫,並在呼叫端標記為已刪除。如果您在計算開始後不再需要引數緩衝區,則捐贈引數緩衝區是安全的。在某些情況下,XLA 可以利用捐贈的緩衝區來減少執行計算所需的記憶體量,例如回收您的其中一個輸入緩衝區以儲存結果。您不應重複使用您捐贈給計算的緩衝區;如果您嘗試這樣做,JAX 會引發錯誤。預設情況下,不會捐贈任何引數緩衝區。

    如果未提供 donate_argnums 也未提供 donate_argnames,則不會捐贈任何引數。如果未提供 donate_argnums 但提供了 donate_argnames,反之亦然,JAX 會使用 inspect.signature(fun) 來尋找對應於 donate_argnames 的任何位置引數(反之亦然)。如果同時提供了 donate_argnumsdonate_argnames,則不會使用 inspect.signature,並且僅將 donate_argnumsdonate_argnames 中列出的實際參數捐贈。

    有關緩衝區捐贈的更多詳細資訊,請參閱 FAQ

  • donate_argnames (str | Iterable[str] | None | None) – 選填,一個字串或字串集合,指定要將哪些具名引數捐贈給計算。請參閱關於 donate_argnums 的註解以取得詳細資訊。如果未提供但已設定 donate_argnums,則預設值基於呼叫 inspect.signature(fun) 以尋找對應的具名引數。

  • keep_unused (bool) – 選填布林值。如果為 False(預設值),JAX 判斷為 fun 未使用的引數可能會從產生的編譯 XLA 可執行檔中刪除。此類引數將不會傳輸到裝置,也不會提供給底層可執行檔。如果為 True,則不會修剪未使用的引數。

  • device (xc.Device | None | None) – 這是一個實驗性功能,API 可能會變更。選填,JIT 編譯函式將在其上執行的裝置。(可用的裝置可以透過 jax.devices() 取得。)預設值繼承自 XLA 的 DeviceAssignment 邏輯,通常是使用 jax.devices()[0]

  • backend (str | None | None) – 這是一個實驗性功能,API 可能會變更。選填,表示 XLA 後端的字串:'cpu''gpu''tpu'

  • inline (bool) – 選填布林值。指定是否應將此函式內嵌到封閉的 jaxprs 中。預設值為 False。

  • abstracted_axes (Any | None | None)

  • compiler_options (dict[str, Any] | None | None)

傳回:

fun 的包裝版本,設定為進行即時編譯。

傳回型別:

pjit.JitWrapped

範例

在以下範例中,selu 可以由 XLA 編譯成單個融合核心

>>> import jax
>>>
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.key(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))  
[-0.54485  0.27744 -0.29255 -0.91421 -0.62452 -0.24748
-0.85743 -0.78232  0.76827  0.59566 ]

若要在裝飾函式時傳遞諸如 static_argnames 之類的引數,常見的模式是使用 functools.partial()

>>> from functools import partial
>>>
>>> @partial(jax.jit, static_argnames=['n'])
... def g(x, n):
...   for i in range(n):
...     x = x ** 2
...   return x
>>>
>>> g(jnp.arange(4), 3)
Array([   0,    1,  256, 6561], dtype=int32)