外部回調#
本教學概述如何使用各種回調函式,這些函式允許 JAX 執行階段在主機上執行 Python 程式碼。JAX 回調的範例包括 jax.pure_callback()
、jax.experimental.io_callback()
和 jax.debug.callback()
。即使在 JAX 轉換下執行時,包括 jit()
、vmap()
、grad()
,您也可以使用它們。
為什麼需要回調?#
回調常式是一種在執行階段執行程式碼的主機端方式。舉一個簡單的例子,假設您想在計算過程中列印某個變數的值。使用簡單的 Python print()
陳述式,看起來像這樣
import jax
@jax.jit
def f(x):
y = x + 1
print("intermediate value: {}".format(y))
return y * 2
result = f(2)
intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>
列印出來的不是執行階段值,而是追蹤時間抽象值(如果您不熟悉 JAX 中的追蹤,可以在 追蹤 中找到很好的入門介紹)。
若要在執行階段列印值,您需要一個回調,例如 jax.debug.print()
(您可以在 除錯簡介 中瞭解更多關於除錯的資訊)
@jax.jit
def f(x):
y = x + 1
jax.debug.print("intermediate value: {}", y)
return y * 2
result = f(2)
intermediate value: 3
其運作方式是將 y
的執行階段值作為 CPU jax.Array
傳回主機進程,主機可以在其中列印它。
回調的種類#
在舊版本的 JAX 中,只有一種可用的回調,實作在 jax.experimental.host_callback()
中。host_callback
常式有一些缺陷,現在已被棄用,轉而使用為不同情況設計的幾種回調
jax.pure_callback()
:適用於純函式:即沒有副作用的函式。jax.experimental.io_callback()
:適用於不純函式:例如,讀取或寫入資料到磁碟的函式。jax.debug.callback()
:適用於應反映編譯器執行行為的函式。
(您先前使用的 jax.debug.print()
函式是 jax.debug.callback()
的包裝器)。
從使用者的角度來看,這三種回調種類主要通過它們允許的轉換和編譯器最佳化來區分。
回調函式 |
支援傳回值 |
|
|
|
|
保證執行 |
---|---|---|---|---|---|---|
✅ |
✅ |
✅ |
❌¹ |
✅ |
❌ |
|
✅ |
✅ |
✅/❌² |
❌ |
✅³ |
✅ |
|
❌ |
✅ |
✅ |
✅ |
✅ |
❌ |
¹ jax.pure_callback
可以與 custom_jvp
搭配使用,使其與自動微分相容
² 僅當 ordered=False
時,jax.experimental.io_callback
才與 vmap
相容。
³ 請注意,io_callback
的 scan
/while_loop
的 vmap
具有複雜的語義,其行為可能會在未來的版本中變更。
探索 pure_callback
#
當您想要主機端執行純函式時,jax.pure_callback()
通常是您應該使用的回調函式:即沒有副作用的函式 (例如列印值、從磁碟讀取資料、更新全域狀態等)。
您傳遞給 jax.pure_callback()
的函式實際上不需要是純函式,但 JAX 的轉換和高階函式會假定它是純函式,這表示它可能會被靜默地省略或多次調用。
import jax
import jax.numpy as jnp
import numpy as np
def f_host(x):
# call a numpy (not jax.numpy) operation:
return np.sin(x).astype(x.dtype)
def f(x):
result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
return jax.pure_callback(f_host, result_shape, x)
x = jnp.arange(5.0)
f(x)
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
由於 pure_callback
可以被省略或複製,因此它可以與 jit
和 vmap
等轉換以及 scan
和 while_loop
等高階基本運算開箱即用相容:
jax.jit(f)(x)
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
jax.vmap(f)(x)
/tmp/ipykernel_889/3691550925.py:11: DeprecationWarning: The default behavior of pure_callback under vmap will soon change. Currently, the default behavior is to generate a sequential vmap (i.e. a loop), but in the future the default will be to raise an error. To keep the current default, set vmap_method='sequential'.
return jax.pure_callback(f_host, result_shape, x)
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
def body_fun(_, x):
return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
然而,由於 JAX 無法內省回調的內容,pure_callback
具有未定義的自動微分語義
jax.grad(f)(x)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.
有關搭配 jax.custom_jvp()
使用 pure_callback
的範例,請參閱下方的範例:搭配 custom_jvp
使用 pure_callback
。
依設計,傳遞給 pure_callback
的函式被視為沒有副作用:這樣做的一個後果是,如果沒有使用函式的輸出,編譯器可能會完全消除回調
def print_something():
print('printing something')
return np.int32(0)
@jax.jit
def f1():
return jax.pure_callback(print_something, np.int32(0))
f1();
printing something
@jax.jit
def f2():
jax.pure_callback(print_something, np.int32(0))
return 1.0
f2();
在 f1
中,回調的輸出在函式的傳回值中使用,因此執行回調,我們會看到列印的輸出。另一方面,在 f2
中,回調的輸出未使用,因此編譯器注意到這一點並消除了函式呼叫。這些是沒有副作用的函式的回調的正確語義。
探索 io_callback
#
與 jax.pure_callback()
相反,jax.experimental.io_callback()
明確地用於不純函式,即具有副作用的函式。
舉例來說,以下是一個回調到全域主機端 numpy 隨機產生器的範例。這是一個不純操作,因為在 numpy 中產生隨機數的副作用是隨機狀態會更新 (請注意,這僅作為 io_callback
的玩具範例,不一定是 JAX 中產生隨機數的建議方式!)。
from jax.experimental import io_callback
from functools import partial
global_rng = np.random.default_rng(0)
def host_side_random_like(x):
"""Generate a random array like x using the global_rng state"""
# We have two side-effects here:
# - printing the shape and dtype
# - calling global_rng, thus updating its state
print(f'generating {x.dtype}{list(x.shape)}')
return global_rng.uniform(size=x.shape).astype(x.dtype)
@jax.jit
def numpy_random_like(x):
return io_callback(host_side_random_like, x, x)
x = jnp.zeros(5)
numpy_random_like(x)
generating float32[5]
Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ], dtype=float32)
io_callback
預設與 vmap
相容
jax.vmap(numpy_random_like)(x)
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.91275555, 0.60663575, 0.72949654, 0.543625 , 0.9350724 ], dtype=float32)
但是請注意,這可能會以任何順序執行對應的回調。因此,舉例來說,如果您在 GPU 上執行此操作,則對應輸出的順序可能會因執行而異。
如果回調的順序必須保留,您可以設定 ordered=True
,在這種情況下,嘗試 vmap
會引發錯誤
@jax.jit
def numpy_random_like_ordered(x):
return io_callback(host_side_random_like, x, x, ordered=True)
jax.vmap(numpy_random_like_ordered)(x)
ValueError: Cannot `vmap` ordered IO callback.
另一方面,無論是否強制排序,scan
和 while_loop
都適用於 io_callback
def body_fun(_, x):
return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544], dtype=float32)
與 pure_callback
類似,如果將可微分變數傳遞給 io_callback
,則它會在自動微分下失敗
jax.grad(numpy_random_like)(x)
ValueError: IO callbacks do not support JVP.
但是,如果回調不依賴可微分變數,它將會執行
@jax.jit
def f(x):
io_callback(lambda: print('hello'), None)
return x
jax.grad(f)(1.0);
hello
與 pure_callback
不同,在這種情況下,即使回調的輸出在後續計算中未使用,編譯器也不會移除回調執行。
探索 debug.callback
#
pure_callback
和 io_callback
都對它們調用的函式的純度強制執行一些假設,並以各種方式限制 JAX 轉換和編譯機制可能執行的操作。debug.callback
本質上對回調函式不做任何假設,使得回調的動作準確地反映 JAX 在程式執行過程中執行的操作。此外,debug.callback
不能將任何值傳回程式。
from jax import debug
def log_value(x):
# This could be an actual logging call; we'll use
# print() for demonstration
print("log:", x)
@jax.jit
def f(x):
debug.callback(log_value, x)
return x
f(1.0);
log: 1.0
debug 回調與 vmap
相容
x = jnp.arange(5.0)
jax.vmap(f)(x);
log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0
也與 grad
和其他自動微分轉換相容
jax.grad(f)(1.0);
log: 1.0
這使得 debug.callback
比 pure_callback
或 io_callback
更適用於通用除錯。
範例:搭配 custom_jvp
使用 pure_callback
#
利用 jax.pure_callback()
的一種強大方法是將其與 jax.custom_jvp
結合使用。(有關 jax.custom_jvp()
的更多詳細資訊,請參閱 JAX 可轉換 Python 函式的自訂導數規則)。
假設您想要為 jax.scipy
或 jax.numpy
包裝器中尚未提供的 scipy 或 numpy 函式建立 JAX 相容的包裝器。
在這裡,我們將考慮為第一類 Bessel 函式建立包裝器,該函式在 scipy.special.jv
中提供。您可以從定義簡單的 pure_callback()
開始
import jax
import jax.numpy as jnp
import scipy.special
def jv(v, z):
v, z = jnp.asarray(v), jnp.asarray(z)
# Require the order v to be integer type: this simplifies
# the JVP rule below.
assert jnp.issubdtype(v.dtype, jnp.integer)
# Promote the input to inexact (float/complex).
# Note that jnp.result_type() accounts for the enable_x64 flag.
z = z.astype(jnp.result_type(float, z.dtype))
# Wrap scipy function to return the expected dtype.
_scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)
# Define the expected shape & dtype of output.
result_shape_dtype = jax.ShapeDtypeStruct(
shape=jnp.broadcast_shapes(v.shape, z.shape),
dtype=z.dtype)
# You use vectorize=True because scipy.special.jv handles broadcasted inputs.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
這讓我們可以從轉換後的 JAX 程式碼中調用 scipy.special.jv()
,包括由 jit()
和 vmap()
轉換時
from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0)
print(j1(z))
[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]
/tmp/ipykernel_889/2939642295.py:25: DeprecationWarning: The vectorized argument of jax.pure_callback is deprecated and setting it will soon raise an error. To avoid an error in the future, and to suppress this warning, please use the vmap_method argument instead.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
以下是使用 jit()
的相同結果
print(jax.jit(j1)(z))
[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]
/tmp/ipykernel_889/2939642295.py:25: DeprecationWarning: The vectorized argument of jax.pure_callback is deprecated and setting it will soon raise an error. To avoid an error in the future, and to suppress this warning, please use the vmap_method argument instead.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
以下是再次使用 vmap()
的相同結果
print(jax.vmap(j1)(z))
[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]
/tmp/ipykernel_889/2939642295.py:25: DeprecationWarning: The vectorized argument of jax.pure_callback is deprecated and setting it will soon raise an error. To avoid an error in the future, and to suppress this warning, please use the vmap_method argument instead.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
但是,如果您調用 grad()
,您會收到錯誤,因為沒有為此函式定義自動微分規則
jax.grad(j1)(z)
/tmp/ipykernel_889/2939642295.py:25: DeprecationWarning: The vectorized argument of jax.pure_callback is deprecated and setting it will soon raise an error. To avoid an error in the future, and to suppress this warning, please use the vmap_method argument instead.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.
讓我們為此定義自訂梯度規則。查看 第一類 Bessel 函式 的定義,您會發現對於參數 z
的導數,存在相對簡單的遞迴關係
對於 \(\nu\) 的梯度更複雜,但由於我們已將 v
參數限制為整數類型,因此為了本範例的目的,您無需擔心其梯度。
您可以使用 jax.custom_jvp()
為您的回調函式定義此自動微分規則
jv = jax.custom_jvp(jv)
@jv.defjvp
def _jv_jvp(primals, tangents):
v, z = primals
_, z_dot = tangents # Note: v_dot is always 0 because v is integer.
jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
return jv(v, z), z_dot * djv_dz
現在計算函式的梯度將能正確運作
j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
-0.06447162
/tmp/ipykernel_889/2939642295.py:25: DeprecationWarning: The vectorized argument of jax.pure_callback is deprecated and setting it will soon raise an error. To avoid an error in the future, and to suppress this warning, please use the vmap_method argument instead.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
此外,由於我們已根據 jv
本身定義了梯度,因此 JAX 的架構意味著您可以免費獲得二階和更高階的導數
jax.hessian(j1)(2.0)
/tmp/ipykernel_889/2939642295.py:25: DeprecationWarning: The vectorized argument of jax.pure_callback is deprecated and setting it will soon raise an error. To avoid an error in the future, and to suppress this warning, please use the vmap_method argument instead.
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
Array(-0.4003078, dtype=float32, weak_type=True)
請記住,雖然這一切都與 JAX 配合良好,但每次調用基於回調的 jv
函式都會導致將輸入資料從裝置傳遞到主機,並將 scipy.special.jv()
的輸出從主機傳回裝置。
在 GPU 或 TPU 等加速器上執行時,每次調用 jv
時,這種資料移動和主機同步都可能導致顯著的額外負荷。
但是,如果您在單個 CPU 上執行 JAX (其中「主機」和「裝置」位於相同的硬體上),JAX 通常會以快速、零複製的方式執行此資料傳輸,使這種模式成為擴充 JAX 功能的相對簡單的方式。