JAX 除錯標記#

JAX 提供標記和上下文管理器,可更輕鬆地捕捉錯誤。

jax_debug_nans 組態選項和上下文管理器#

摘要: 啟用 jax_debug_nans 標記以自動偵測何時在 jax.jit 編譯的程式碼中產生 NaN (但不適用於 jax.pmapjax.pjit 編譯的程式碼)。

jax_debug_nans 是一個 JAX 標記,啟用時,會在偵測到 NaN 時自動引發錯誤。它針對 JIT 編譯具有特殊處理 – 當從 JIT 函數偵測到 NaN 輸出時,該函數會以 eager 方式重新執行 (即不經編譯),並會在產生 NaN 的特定 primitive 處擲回錯誤。

用法#

如果您想要追蹤 NaN 在您的函數或梯度中發生的位置,您可以透過以下方式開啟 NaN 檢查器:

  • 設定 JAX_DEBUG_NANS=True 環境變數;

  • 在您的主要檔案頂端附近新增 jax.config.update("jax_debug_nans", True)

  • jax.config.parse_flags_with_absl() 新增至您的主要檔案,然後使用類似 --jax_debug_nans=True 的命令列標記來設定選項;

範例#

import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!

jax_debug_nans 的優點與限制#

優點#
  • 易於應用

  • 精確地偵測 NaN 的產生位置

  • 擲回標準 Python 例外,並與 PDB 事後除錯相容

限制#
  • jax.pmapjax.pjit 不相容

  • 以 eager 方式重新執行函數可能會很慢

  • 誤報錯誤 (例如,有意建立的 NaN)

jax_disable_jit 組態選項和上下文管理器#

摘要: 啟用 jax_disable_jit 標記以停用 JIT 編譯,從而能夠使用傳統的 Python 除錯工具,如 printpdb

jax_disable_jit 是一個 JAX 標記,啟用時,會在整個 JAX 中停用 JIT 編譯 (包括在控制流程函數中,如 jax.lax.condjax.lax.scan)。

用法#

您可以透過以下方式停用 JIT 編譯:

  • 設定 JAX_DISABLE_JIT=True 環境變數;

  • 在您的主要檔案頂端附近新增 jax.config.update("jax_disable_jit", True)

  • jax.config.parse_flags_with_absl() 新增至您的主要檔案,然後使用類似 --jax_disable_jit=True 的命令列標記來設定選項;

範例#

import jax
jax.config.update("jax_disable_jit", True)

def f(x):
  y = jnp.log(x)
  if jnp.isnan(y):
    breakpoint()
  return y
jax.jit(f)(-2.)  # ==> Enters PDB breakpoint!

jax_disable_jit 的優點與限制#

優點#
  • 易於應用

  • 能夠使用 Python 的內建 breakpointprint

  • 擲回標準 Python 例外,並與 PDB 事後除錯相容

限制#
  • jax.pmapjax.pjit 不相容

  • 在沒有 JIT 編譯的情況下執行函數可能會很慢