jax.jit#
- jax.jit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None, compiler_options=None)[原始碼]#
設定
fun
以使用 XLA 進行即時編譯。- 參數:
fun (Callable) –
要進行 JIT 編譯的函式。
fun
應為純函式。fun
的引數和傳回值應為陣列、純量或它們的(巢狀)標準 Python 容器(tuple/list/dict)。由static_argnums
指示的位置引數可以是任何可雜湊的型別。靜態引數會包含在編譯快取金鑰中,這就是為什麼必須定義雜湊和相等運算子的原因。JAX 會保留對fun
的弱參考,以用作編譯快取金鑰,因此物件fun
必須是弱參考的。in_shardings – 選填,一個
Sharding
或 pytree,其具有Sharding
葉節點和結構,該結構是傳遞給fun
的位置引數 tuple 的樹狀前綴。如果提供,則傳遞給fun
的位置引數必須具有與in_shardings
相容的分片,否則會引發錯誤,並且編譯後的計算具有對應於in_shardings
的輸入分片。如果未提供,則編譯後的計算的輸入分片會從引數分片推斷。out_shardings – 選填,一個
Sharding
或 pytree,其具有Sharding
葉節點和結構,該結構是fun
輸出的樹狀前綴。如果提供,則其效果與將對應的jax.lax.with_sharding_constraint
應用於fun()
的輸出相同。static_argnums (int | Sequence[int] | None | None) –
選填,一個整數或整數集合,指定要將哪些位置引數視為靜態(追蹤和編譯時常數)。
靜態引數應為可雜湊的,表示
__hash__
和__eq__
均已實作,且為不可變的。否則,它們可以是任意 Python 物件。使用這些常數的不同值呼叫 JIT 編譯函式將觸發重新編譯。非陣列式或其容器的引數必須標記為靜態。如果未提供
static_argnums
也未提供static_argnames
,則不會將任何引數視為靜態。如果未提供static_argnums
但提供了static_argnames
,反之亦然,JAX 會使用inspect.signature(fun)
來尋找對應於static_argnames
的任何位置引數(反之亦然)。如果同時提供了static_argnums
和static_argnames
,則不會使用inspect.signature
,並且僅將static_argnums
或static_argnames
中列出的實際參數視為靜態。static_argnames (str | Iterable[str] | None | None) – 選填,一個字串或字串集合,指定要將哪些具名引數視為靜態(編譯時常數)。請參閱關於
static_argnums
的註解以取得詳細資訊。如果未提供但已設定static_argnums
,則預設值基於呼叫inspect.signature(fun)
以尋找對應的具名引數。donate_argnums (int | Sequence[int] | None | None) –
選填,整數集合,用於指定哪些位置引數緩衝區可以被計算覆寫,並在呼叫端標記為已刪除。如果您在計算開始後不再需要引數緩衝區,則捐贈引數緩衝區是安全的。在某些情況下,XLA 可以利用捐贈的緩衝區來減少執行計算所需的記憶體量,例如回收您的其中一個輸入緩衝區以儲存結果。您不應重複使用您捐贈給計算的緩衝區;如果您嘗試這樣做,JAX 會引發錯誤。預設情況下,不會捐贈任何引數緩衝區。
如果未提供
donate_argnums
也未提供donate_argnames
,則不會捐贈任何引數。如果未提供donate_argnums
但提供了donate_argnames
,反之亦然,JAX 會使用inspect.signature(fun)
來尋找對應於donate_argnames
的任何位置引數(反之亦然)。如果同時提供了donate_argnums
和donate_argnames
,則不會使用inspect.signature
,並且僅將donate_argnums
或donate_argnames
中列出的實際參數捐贈。有關緩衝區捐贈的更多詳細資訊,請參閱 FAQ。
donate_argnames (str | Iterable[str] | None | None) – 選填,一個字串或字串集合,指定要將哪些具名引數捐贈給計算。請參閱關於
donate_argnums
的註解以取得詳細資訊。如果未提供但已設定donate_argnums
,則預設值基於呼叫inspect.signature(fun)
以尋找對應的具名引數。keep_unused (bool) – 選填布林值。如果為 False(預設值),JAX 判斷為 fun 未使用的引數可能會從產生的編譯 XLA 可執行檔中刪除。此類引數將不會傳輸到裝置,也不會提供給底層可執行檔。如果為 True,則不會修剪未使用的引數。
device (xc.Device | None | None) – 這是一個實驗性功能,API 可能會變更。選填,JIT 編譯函式將在其上執行的裝置。(可用的裝置可以透過
jax.devices()
取得。)預設值繼承自 XLA 的 DeviceAssignment 邏輯,通常是使用jax.devices()[0]
。backend (str | None | None) – 這是一個實驗性功能,API 可能會變更。選填,表示 XLA 後端的字串:
'cpu'
、'gpu'
或'tpu'
。inline (bool) – 選填布林值。指定是否應將此函式內嵌到封閉的 jaxprs 中。預設值為 False。
abstracted_axes (Any | None | None)
- 傳回:
fun
的包裝版本,設定為進行即時編譯。- 傳回型別:
pjit.JitWrapped
範例
在以下範例中,
selu
可以由 XLA 編譯成單個融合核心>>> import jax >>> >>> @jax.jit ... def selu(x, alpha=1.67, lmbda=1.05): ... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha) >>> >>> key = jax.random.key(0) >>> x = jax.random.normal(key, (10,)) >>> print(selu(x)) [-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 -0.85743 -0.78232 0.76827 0.59566 ]
若要在裝飾函式時傳遞諸如
static_argnames
之類的引數,常見的模式是使用functools.partial()
>>> from functools import partial >>> >>> @partial(jax.jit, static_argnames=['n']) ... def g(x, n): ... for i in range(n): ... x = x ** 2 ... return x >>> >>> g(jnp.arange(4), 3) Array([ 0, 1, 256, 6561], dtype=int32)