編譯後的列印和斷點#

jax.debug 套件提供了一些有用的工具,用於檢查編譯後函式內的值。

使用 jax.debug.print 和其他除錯回呼進行除錯#

摘要: 使用 jax.debug.print() 在編譯後的函式 (例如 jax.jitjax.pmap 修飾的函式) 中,將追蹤的陣列值列印到 stdout

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯

對於某些轉換,例如 jax.gradjax.vmap,您可以使用 Python 的內建 print 函式來列印數值。但是 print 無法與 jax.jitjax.pmap 搭配使用,因為這些轉換會延遲數值評估。因此請改用 jax.debug.print

在語意上,jax.debug.print 大致等同於以下 Python 函式

def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  print(fmt.format(*args, **kwargs))

但它可以由 JAX staged out 和轉換。請參閱 API 參考 以取得更多詳細資訊。

請注意,fmt 不能是 f-string,因為 f-string 會立即格式化,而對於 jax.debug.print,我們希望延遲格式化直到稍後。

何時使用「debug」列印?#

您應該在 JAX 轉換 (例如 jitvmap 和其他轉換) 中,針對動態 (即追蹤的) 陣列值使用 jax.debug.print。對於靜態值 (例如陣列形狀或 dtype) 的列印,您可以使用一般的 Python print 陳述式。

為何使用「debug」列印?#

以除錯之名,jax.debug.print 可以揭示關於計算如何評估的資訊

xs = jnp.arange(3.)

def f(x):
  jax.debug.print("x: {}", x)
  y = jnp.sin(x)
  jax.debug.print("y: {}", y)
  return y
jax.vmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
#         x: 2.0
#         y: 0.0
#         y: 0.841471
#         y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
#         y: 0.0
#         x: 1.0
#         y: 0.841471
#         x: 2.0
#         y: 0.9092974

請注意,列印的結果順序不同!

透過揭示這些內部運作方式,jax.debug.print 的輸出不遵守 JAX 通常的語意保證,例如 jax.vmap(f)(xs)jax.lax.map(f, xs) 計算相同的東西 (以不同的方式)。然而,這些評估順序的細節正是我們在除錯時可能想看到的!

因此,請將 jax.debug.print 用於除錯,而不是在語意保證很重要時使用。

jax.debug.print 的更多範例#

除了上面使用 jitvmap 的範例之外,這裡還有一些範例供您參考。

jax.pmap 下列印#

當使用 jax.pmap 時,jax.debug.print 的順序可能會重新排列!

xs = jnp.arange(2.)

def f(x):
  jax.debug.print("x: {}", x)
  return x
jax.pmap(f)(xs)
# Prints: x: 1.0
#         x: 0.0
# OR
# Prints: x: 1.0
#         x: 0.0

jax.grad 下列印#

jax.grad 下,jax.debug.print 只會在前向傳遞中列印

def f(x):
  jax.debug.print("x: {}", x)
  return x * 2.

jax.grad(f)(1.)
# Prints: x: 1.0

此行為類似於 Python 的內建 printjax.grad 下的運作方式。但是透過在此處使用 jax.debug.print,即使呼叫者套用 jax.jit,行為仍然相同。

若要在反向傳遞中列印,只需使用 jax.custom_vjp

@jax.custom_vjp
def print_grad(x):
  return x

def print_grad_fwd(x):
  return x, None

def print_grad_bwd(_, x_grad):
  jax.debug.print("x_grad: {}", x_grad)
  return (x_grad,)

print_grad.defvjp(print_grad_fwd, print_grad_bwd)


def f(x):
  x = print_grad(x)
  return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0

在其他轉換中列印#

jax.debug.print 也適用於其他轉換,例如 pjit

使用 jax.debug.callback 進行更多控制#

實際上,jax.debug.printjax.debug.callback 的精簡便利包裝函式,可以直接使用它來更精細地控制字串格式設定,甚至是輸出的種類。

在語意上,jax.debug.callback 大致等同於以下 Python 函式

def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  fun(*args, **kwargs)
  return None

jax.debug.print 相同,這些回呼應僅用於除錯輸出,例如列印或繪圖。列印和繪圖相當無害,但如果您將其用於其他用途,則其行為在轉換下可能會讓您感到驚訝。例如,使用 jax.debug.callback 進行計時操作並不安全,因為回呼可能會重新排序且是非同步的 (請參閱下文)。

尖銳之處#

與大多數 JAX API 一樣,如果您不小心,jax.debug.print 也可能會讓您受傷。

列印結果的順序#

當對 jax.debug.print 的不同呼叫涉及彼此不相依的引數時,它們可能會在 staged out 時重新排序,例如透過 jax.jit

@jax.jit
def f(x, y):
  jax.debug.print("x: {}", x)
  jax.debug.print("y: {}", y)
  return x + y

f(2., 3.)
# Prints: x: 2.0
#         y: 3.0
# OR
# Prints: y: 3.0
#         x: 2.0

為什麼?在底層,編譯器會取得 staged-out 計算的功能表示法,其中 Python 函式的命令式順序會遺失,而只剩下資料相依性。此變更對於使用功能純程式碼的使用者是不可見的,但在存在列印等副作用的情況下,它會變得明顯。

若要保留 jax.debug.print 原始順序 (如同在 Python 函式中所撰寫),您可以使用 jax.debug.print(..., ordered=True),這將確保保留列印的相對順序。但是,在 jax.pmap 和其他涉及平行處理的 JAX 轉換下使用 ordered=True 會引發錯誤,因為在平行執行下無法保證順序。

非同步回呼#

根據後端而定,jax.debug.print 可能會非同步發生,即不在您的主程式執行緒中發生。這表示即使在您的 JAX 函式傳回值之後,值仍可能會列印到您的螢幕上。

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
# <do something else>
# Prints: x: 2.

若要封鎖函式中的 jax.debug.print,您可以呼叫 jax.effects_barrier(),這將等待直到函式中任何剩餘的副作用也完成為止

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
jax.effects_barrier()
# Prints: x: 2.
# <do something else>

效能影響#

不必要的實體化#

雖然 jax.debug.print 的設計目的是要將效能影響降到最低,但它可能會干擾編譯器最佳化,並可能影響 JAX 程式的記憶體配置。

def f(w, b, x):
  logits = w.dot(x) + b
  jax.debug.print("logits: {}", logits)
  return jax.nn.relu(logits)

在此範例中,我們正在線性層和啟動函式之間列印中間值。XLA 等編譯器可以執行融合最佳化,這可能會避免在記憶體中實體化 logits。但是當我們在 logits 上使用 jax.debug.print 時,我們正在強制將這些中間值實體化,這可能會減慢程式速度並增加記憶體使用量。

此外,當將 jax.debug.printjax.pjit 搭配使用時,會發生全域同步處理,這會將值實體化在單一裝置上。

回呼額外負擔#

jax.debug.print 本質上會產生加速器與其主機之間的通訊。底層機制因後端而異 (例如 GPU 與 TPU),但在所有情況下,我們都需要將列印的值從裝置複製到主機。在 CPU 情況下,此額外負擔較小。

此外,當將 jax.debug.printjax.pjit 搭配使用時,會發生全域同步處理,這會增加一些額外負擔。

jax.debug.print 的優點和限制#

優點#

  • 列印除錯既簡單又直覺

  • jax.debug.callback 可以用於其他無害的副作用

限制#

  • 新增列印陳述式是手動過程

  • 可能會有效能影響

使用 jax.debug.breakpoint() 進行互動式檢查#

摘要: 使用 jax.debug.breakpoint() 暫停 JAX 程式的執行以檢查值

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution!

JAX debugger

jax.debug.breakpoint() 實際上只是 jax.debug.callback(...) 的應用,它會擷取關於呼叫堆疊的資訊。因此,它具有與 jax.debug.print 相同的轉換行為 (例如,vmap-ing jax.debug.breakpoint() 會在映射軸上展開它)。

用法#

在編譯後的 JAX 函式中呼叫 jax.debug.breakpoint() 將在程式到達斷點時暫停您的程式。您會看到類似 pdb 的提示,讓您可以檢查呼叫堆疊中的值。與 pdb 不同,您將無法逐步執行,但您可以繼續執行。

除錯器命令

  • help - 列印可用的命令

  • p - 評估運算式並列印其結果

  • pp - 評估運算式並以美觀方式列印其結果

  • u(p) - 向上移動堆疊框架

  • d(own) - 向下移動堆疊框架

  • w(here)/bt - 列印回溯追蹤

  • l(ist) - 列印程式碼上下文

  • c(ont(inue)) - 繼續執行程式

  • q(uit)/exit - 結束程式 (在 TPU 上無法運作)

範例#

jax.lax.cond 一起使用#

當與 jax.lax.cond 結合使用時,除錯器可以成為偵測 naninf 的有用工具。

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z
f(2., 0.) # ==> Pauses during execution!

尖銳之處#

由於 jax.debug.breakpoint 只是 jax.debug.callback 的應用,因此它具有與 jax.debug.print 相同的尖銳之處,以及一些額外的注意事項

  • jax.debug.breakpointjax.debug.print 實體化更多的中間值,因為它會強制實體化呼叫堆疊中的所有值

  • jax.debug.breakpointjax.debug.print 具有更多的執行階段額外負擔,因為它必須可能將 JAX 程式中的所有中間值從裝置複製到主機。

jax.debug.breakpoint() 的優點和限制#

優點#

  • 簡單、直覺且 (在某種程度上) 標準

  • 可以同時檢查呼叫堆疊上下中的許多值

限制#

  • 可能需要使用許多斷點來精確找出錯誤來源

  • 實體化許多中間值