偵錯執行階段值#
您是否有梯度爆炸的問題?NaN 讓您咬牙切齒嗎?只是想探查一下計算中的中間值嗎?看看以下 JAX 偵錯工具!此頁面有摘要,您可以點擊底部的「閱讀更多」連結以了解更多資訊。
目錄
使用 jax.debug
進行互動式檢查#
完整指南請見此處
摘要: 使用 jax.debug.print()
在 jax.jit
-、jax.pmap
- 和 pjit
-裝飾的函數中將值列印到 stdout,並使用 jax.debug.breakpoint()
暫停已編譯函數的執行,以檢查呼叫堆疊中的值
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.breakpoint()
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
閱讀更多.
使用 jax.experimental.checkify
進行功能性錯誤檢查#
完整指南請見此處
摘要: Checkify 讓您可以將可 jit
化的執行階段錯誤檢查 (例如,超出邊界索引) 新增到您的 JAX 程式碼中。將 checkify.checkify
轉換與類似 assert 的 checkify.check
函數一起使用,以將執行階段檢查新增到 JAX 程式碼
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
您也可以使用 checkify 自動新增常見檢查
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
閱讀更多.
使用 JAX 的偵錯旗標擲出 Python 錯誤#
完整指南請見此處
摘要: 啟用 jax_debug_nans
旗標以自動偵測何時在 jax.jit
-編譯的程式碼中產生 NaN (但不適用於 jax.pmap
或 jax.pjit
-編譯的程式碼),並啟用 jax_disable_jit
旗標以停用 JIT 編譯,從而可以使用傳統的 Python 偵錯工具,例如 print
和 pdb
。
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
閱讀更多.