jax.disable_jit#

jax.disable_jit(disable=True)[原始碼]#

停用 jit() 行為在其動態上下文下的上下文管理器。

為了除錯,擁有一種機制可以在動態上下文中停用所有地方的 jit() 是很有用的。請注意,這不僅停用了使用者明確使用的 jit(),還將移除 JAX 程式庫使用的任何隱含 JIT 編譯:這包括傳遞給較高階 primitives(如 scan()while_loop())的 bodycond 函式的隱含 JIT 計算、jax.numpy 函式實作中使用的 JIT,以及 API 實作中任何使用 jit() 的情況。但請注意,即使在 disable_jit 下,個別的 primitive 運算仍然會像正常的 eager op-by-op 執行一樣由 XLA 編譯。

對於資料依賴於 jitted 函式參數的值,會進行追蹤和抽象化。例如,抽象值可能是一個 ShapedArray 實例,表示具有給定形狀和 dtype 的所有可能陣列的集合,但不表示具有特定值的具體陣列。如果您在 jitted 函式中使用良性副作用操作(例如 print),您可能會注意到這些。

>>> import jax
>>>
>>> @jax.jit
... def f(x):
...   y = x * 2
...   print("Value of y is", y)
...   return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))  
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace...>
[5 7 9]

在此,y 已被 jit() 抽象化為 ShapedArray,它表示具有固定形狀和類型但具有任意值的陣列。y 的值也會被追蹤。如果我們想要在除錯時看到具體值,並且也避免追蹤器,我們可以使用 disable_jit() 上下文管理器。

>>> import jax
>>>
>>> with jax.disable_jit():
...   print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]
參數:

disable (bool)