偵錯簡介#
本節向您介紹一組內建的 JAX 偵錯方法 — jax.debug.print()
、jax.debug.breakpoint()
和 jax.debug.callback()
— 您可以將它們與各種 JAX 轉換一起使用。
讓我們先從 jax.debug.print()
開始。
jax.debug.print
用於簡單檢查#
以下是一個經驗法則
對於使用
jax.jit()
、jax.vmap()
等轉換追蹤 (動態) 陣列值,請使用jax.debug.print()
。對於靜態值,例如 dtype 和陣列形狀,請使用 Python
print()
。
從 即時編譯 回想一下,當使用 jax.jit()
轉換函式時,Python 程式碼會使用抽象追蹤器來代替您的陣列執行。因此,Python print()
函式只會列印此追蹤器值
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
print("print(x) ->", x)
y = jnp.sin(x)
print("print(y) ->", y)
return y
result = f(2.)
print(x) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>
print(y) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>
Python 的 print
在執行階段值存在之前,於追蹤時執行。如果您想要列印實際的執行階段值,可以使用 jax.debug.print()
@jax.jit
def f(x):
jax.debug.print("jax.debug.print(x) -> {x}", x=x)
y = jnp.sin(x)
jax.debug.print("jax.debug.print(y) -> {y}", y=y)
return y
result = f(2.)
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314
同樣地,在 jax.vmap()
中,使用 Python 的 print
只會列印追蹤器;若要列印正在映射的值,請使用 jax.debug.print()
def f(x):
jax.debug.print("jax.debug.print(x) -> {}", x)
y = jnp.sin(x)
jax.debug.print("jax.debug.print(y) -> {}", y)
return y
xs = jnp.arange(3.)
result = jax.vmap(f)(xs)
jax.debug.print(x) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(y) -> 0.9092974066734314
以下是使用 jax.lax.map()
的結果,這是一個循序映射而不是向量化
result = jax.lax.map(f, xs)
jax.debug.print(x) -> 0.0
jax.debug.print(y) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314
請注意,順序不同,因為 jax.vmap()
和 jax.lax.map()
以不同的方式計算相同的結果。偵錯時,評估順序的詳細資訊正是您可能需要檢查的內容。
以下是使用 jax.grad()
的範例,其中 jax.debug.print()
只會列印正向傳遞。在這種情況下,行為類似於 Python 的 print()
,但如果您在呼叫期間應用 jax.jit()
,則行為是一致的。
def f(x):
jax.debug.print("jax.debug.print(x) -> {}", x)
return x ** 2
result = jax.grad(f)(1.)
jax.debug.print(x) -> 1.0
有時,當引數彼此不相依時,使用 JAX 轉換 staged out 時,呼叫 jax.debug.print()
可能會以不同的順序列印它們。如果您需要原始順序,例如先 x: ...
然後 y: ...
,請新增 ordered=True
參數。
例如
@jax.jit
def f(x, y):
jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
return x + y
f(1, 2)
jax.debug.print(x) -> 1
jax.debug.print(y) -> 2
Array(3, dtype=int32, weak_type=True)
若要深入瞭解 jax.debug.print()
及其尖銳之處,請參閱 進階偵錯。
jax.debug.breakpoint
用於類似 pdb
的偵錯#
摘要: 使用 jax.debug.breakpoint()
暫停 JAX 程式的執行,以檢查值。
若要在偵錯期間暫停編譯後的 JAX 程式,您可以使用 jax.debug.breakpoint()
。提示符號類似於 Python pdb
,它允許您檢查呼叫堆疊中的值。實際上,jax.debug.breakpoint()
是 jax.debug.callback()
的應用,它會擷取有關呼叫堆疊的資訊。
若要列印 breakpoint
偵錯會話期間的所有可用命令,請使用 help
命令。(完整的偵錯器命令、尖銳之處、其優點和限制在 進階偵錯 中涵蓋。)
以下是一個偵錯器會話可能的外觀範例
@jax.jit
def f(x):
y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution
對於依賴值的斷點,您可以使用執行階段條件,例如 jax.lax.cond()
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
jax.lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 1.) # ==> No breakpoint
Array(2., dtype=float32, weak_type=True)
f(2., 0.) # ==> Pauses during execution
jax.debug.callback
用於偵錯期間的更多控制#
jax.debug.print()
和 jax.debug.breakpoint()
都是使用更彈性的 jax.debug.callback()
實作的,它可以更有效地控制透過 Python 回調執行的主機端邏輯。它與 jax.jit()
、jax.vmap()
、jax.grad()
和其他轉換相容 (如需更多資訊,請參閱 回調的類型 表格在 外部回調 中)。
例如
import logging
def log_value(x):
logging.warning(f'Logged value: {x}')
@jax.jit
def f(x):
jax.debug.callback(log_value, x)
return x
f(1.0);
WARNING:root:Logged value: 1.0
此回調與其他轉換相容,包括 jax.vmap()
和 jax.grad()
x = jnp.arange(5.0)
jax.vmap(f)(x);
WARNING:root:Logged value: 0.0
WARNING:root:Logged value: 1.0
WARNING:root:Logged value: 2.0
WARNING:root:Logged value: 3.0
WARNING:root:Logged value: 4.0
jax.grad(f)(1.0);
WARNING:root:Logged value: 1.0
這可以使 jax.debug.callback()
對於通用偵錯非常有用。
您可以在 外部回調 中瞭解更多關於 jax.debug.callback()
和其他種類 JAX 回調的資訊。
下一步#
查看 進階偵錯,以瞭解更多關於 JAX 中的偵錯資訊。