JAX 除錯標記#
JAX 提供標記和上下文管理器,可更輕鬆地捕捉錯誤。
jax_debug_nans
組態選項和上下文管理器#
摘要: 啟用 jax_debug_nans
標記以自動偵測何時在 jax.jit
編譯的程式碼中產生 NaN (但不適用於 jax.pmap
或 jax.pjit
編譯的程式碼)。
jax_debug_nans
是一個 JAX 標記,啟用時,會在偵測到 NaN 時自動引發錯誤。它針對 JIT 編譯具有特殊處理 – 當從 JIT 函數偵測到 NaN 輸出時,該函數會以 eager 方式重新執行 (即不經編譯),並會在產生 NaN 的特定 primitive 處擲回錯誤。
用法#
如果您想要追蹤 NaN 在您的函數或梯度中發生的位置,您可以透過以下方式開啟 NaN 檢查器:
設定
JAX_DEBUG_NANS=True
環境變數;在您的主要檔案頂端附近新增
jax.config.update("jax_debug_nans", True)
;將
jax.config.parse_flags_with_absl()
新增至您的主要檔案,然後使用類似--jax_debug_nans=True
的命令列標記來設定選項;
範例#
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
jax_debug_nans
的優點與限制#
優點#
易於應用
精確地偵測 NaN 的產生位置
擲回標準 Python 例外,並與 PDB 事後除錯相容
限制#
與
jax.pmap
或jax.pjit
不相容以 eager 方式重新執行函數可能會很慢
誤報錯誤 (例如,有意建立的 NaN)
jax_disable_jit
組態選項和上下文管理器#
摘要: 啟用 jax_disable_jit
標記以停用 JIT 編譯,從而能夠使用傳統的 Python 除錯工具,如 print
和 pdb
jax_disable_jit
是一個 JAX 標記,啟用時,會在整個 JAX 中停用 JIT 編譯 (包括在控制流程函數中,如 jax.lax.cond
和 jax.lax.scan
)。
用法#
您可以透過以下方式停用 JIT 編譯:
設定
JAX_DISABLE_JIT=True
環境變數;在您的主要檔案頂端附近新增
jax.config.update("jax_disable_jit", True)
;將
jax.config.parse_flags_with_absl()
新增至您的主要檔案,然後使用類似--jax_disable_jit=True
的命令列標記來設定選項;
範例#
import jax
jax.config.update("jax_disable_jit", True)
def f(x):
y = jnp.log(x)
if jnp.isnan(y):
breakpoint()
return y
jax.jit(f)(-2.) # ==> Enters PDB breakpoint!
jax_disable_jit
的優點與限制#
優點#
易於應用
能夠使用 Python 的內建
breakpoint
和print
擲回標準 Python 例外,並與 PDB 事後除錯相容
限制#
與
jax.pmap
或jax.pjit
不相容在沒有 JIT 編譯的情況下執行函數可能會很慢