編譯後的列印和斷點#
jax.debug
套件提供了一些有用的工具,用於檢查編譯後函式內的值。
使用 jax.debug.print
和其他除錯回呼進行除錯#
摘要: 使用 jax.debug.print()
在編譯後的函式 (例如 jax.jit
或 jax.pmap
修飾的函式) 中,將追蹤的陣列值列印到 stdout
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯
對於某些轉換,例如 jax.grad
和 jax.vmap
,您可以使用 Python 的內建 print
函式來列印數值。但是 print
無法與 jax.jit
或 jax.pmap
搭配使用,因為這些轉換會延遲數值評估。因此請改用 jax.debug.print
!
在語意上,jax.debug.print
大致等同於以下 Python 函式
def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
print(fmt.format(*args, **kwargs))
但它可以由 JAX staged out 和轉換。請參閱 API 參考
以取得更多詳細資訊。
請注意,fmt
不能是 f-string,因為 f-string 會立即格式化,而對於 jax.debug.print
,我們希望延遲格式化直到稍後。
何時使用「debug」列印?#
您應該在 JAX 轉換 (例如 jit
、vmap
和其他轉換) 中,針對動態 (即追蹤的) 陣列值使用 jax.debug.print
。對於靜態值 (例如陣列形狀或 dtype) 的列印,您可以使用一般的 Python print
陳述式。
為何使用「debug」列印?#
以除錯之名,jax.debug.print
可以揭示關於計算如何評估的資訊
xs = jnp.arange(3.)
def f(x):
jax.debug.print("x: {}", x)
y = jnp.sin(x)
jax.debug.print("y: {}", y)
return y
jax.vmap(f)(xs)
# Prints: x: 0.0
# x: 1.0
# x: 2.0
# y: 0.0
# y: 0.841471
# y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
# y: 0.0
# x: 1.0
# y: 0.841471
# x: 2.0
# y: 0.9092974
請注意,列印的結果順序不同!
透過揭示這些內部運作方式,jax.debug.print
的輸出不遵守 JAX 通常的語意保證,例如 jax.vmap(f)(xs)
和 jax.lax.map(f, xs)
計算相同的東西 (以不同的方式)。然而,這些評估順序的細節正是我們在除錯時可能想看到的!
因此,請將 jax.debug.print
用於除錯,而不是在語意保證很重要時使用。
jax.debug.print
的更多範例#
除了上面使用 jit
和 vmap
的範例之外,這裡還有一些範例供您參考。
在 jax.pmap
下列印#
當使用 jax.pmap
時,jax.debug.print
的順序可能會重新排列!
xs = jnp.arange(2.)
def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 1.0
# x: 0.0
# OR
# Prints: x: 1.0
# x: 0.0
在 jax.grad
下列印#
在 jax.grad
下,jax.debug.print
只會在前向傳遞中列印
def f(x):
jax.debug.print("x: {}", x)
return x * 2.
jax.grad(f)(1.)
# Prints: x: 1.0
此行為類似於 Python 的內建 print
在 jax.grad
下的運作方式。但是透過在此處使用 jax.debug.print
,即使呼叫者套用 jax.jit
,行為仍然相同。
若要在反向傳遞中列印,只需使用 jax.custom_vjp
@jax.custom_vjp
def print_grad(x):
return x
def print_grad_fwd(x):
return x, None
def print_grad_bwd(_, x_grad):
jax.debug.print("x_grad: {}", x_grad)
return (x_grad,)
print_grad.defvjp(print_grad_fwd, print_grad_bwd)
def f(x):
x = print_grad(x)
return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0
在其他轉換中列印#
jax.debug.print
也適用於其他轉換,例如 pjit
。
使用 jax.debug.callback
進行更多控制#
實際上,jax.debug.print
是 jax.debug.callback
的精簡便利包裝函式,可以直接使用它來更精細地控制字串格式設定,甚至是輸出的種類。
在語意上,jax.debug.callback
大致等同於以下 Python 函式
def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
fun(*args, **kwargs)
return None
與 jax.debug.print
相同,這些回呼應僅用於除錯輸出,例如列印或繪圖。列印和繪圖相當無害,但如果您將其用於其他用途,則其行為在轉換下可能會讓您感到驚訝。例如,使用 jax.debug.callback
進行計時操作並不安全,因為回呼可能會重新排序且是非同步的 (請參閱下文)。
jax.debug.print
的優點和限制#
優點#
列印除錯既簡單又直覺
jax.debug.callback
可以用於其他無害的副作用
限制#
新增列印陳述式是手動過程
可能會有效能影響
使用 jax.debug.breakpoint()
進行互動式檢查#
摘要: 使用 jax.debug.breakpoint()
暫停 JAX 程式的執行以檢查值
@jax.jit
def f(x):
y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution!
jax.debug.breakpoint()
實際上只是 jax.debug.callback(...)
的應用,它會擷取關於呼叫堆疊的資訊。因此,它具有與 jax.debug.print
相同的轉換行為 (例如,vmap
-ing jax.debug.breakpoint()
會在映射軸上展開它)。
用法#
在編譯後的 JAX 函式中呼叫 jax.debug.breakpoint()
將在程式到達斷點時暫停您的程式。您會看到類似 pdb
的提示,讓您可以檢查呼叫堆疊中的值。與 pdb
不同,您將無法逐步執行,但您可以繼續執行。
除錯器命令
help
- 列印可用的命令p
- 評估運算式並列印其結果pp
- 評估運算式並以美觀方式列印其結果u(p)
- 向上移動堆疊框架d(own)
- 向下移動堆疊框架w(here)/bt
- 列印回溯追蹤l(ist)
- 列印程式碼上下文c(ont(inue))
- 繼續執行程式q(uit)/exit
- 結束程式 (在 TPU 上無法運作)
範例#
與 jax.lax.cond
一起使用#
當與 jax.lax.cond
結合使用時,除錯器可以成為偵測 nan
或 inf
的有用工具。
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 0.) # ==> Pauses during execution!
尖銳之處#
由於 jax.debug.breakpoint
只是 jax.debug.callback
的應用,因此它具有與 jax.debug.print
相同的尖銳之處,以及一些額外的注意事項
jax.debug.breakpoint
比jax.debug.print
實體化更多的中間值,因為它會強制實體化呼叫堆疊中的所有值jax.debug.breakpoint
比jax.debug.print
具有更多的執行階段額外負擔,因為它必須可能將 JAX 程式中的所有中間值從裝置複製到主機。
jax.debug.breakpoint()
的優點和限制#
優點#
簡單、直覺且 (在某種程度上) 標準
可以同時檢查呼叫堆疊上下中的許多值
限制#
可能需要使用許多斷點來精確找出錯誤來源
實體化許多中間值