在 JAX 中編寫自訂 Jaxpr 解譯器#
JAX 提供了數種可組合的函式轉換 (jit
、grad
、vmap
等),可讓您編寫簡潔、加速的程式碼。
在此,我們將展示如何透過編寫自訂 Jaxpr 解譯器,將您自己的函式轉換新增至系統。而且我們將免費獲得與所有其他轉換的可組合性。
此範例使用 JAX 內部 API,這些 API 可能隨時中斷。API 文件 中未提及的任何內容都應視為內部內容。
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random
JAX 在做什麼?#
JAX 提供類似 NumPy 的 API 用於數值計算,可以直接使用,但 JAX 的真正威力來自可組合的函式轉換。以 jit
函式轉換為例,它接收一個函式並傳回一個語義上相同的函式,但由 XLA 為加速器延遲編譯。
x = random.normal(random.key(0), (5000, 5000))
def f(w, b, x):
return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f)
當我們呼叫 fast_f
時,會發生什麼事?JAX 追蹤函式並建構 XLA 計算圖。然後,該圖會經過 JIT 編譯並執行。其他轉換的工作方式也類似,它們會先追蹤函式,然後以某種方式處理輸出追蹤。若要深入瞭解 Jax 的追蹤機制,您可以參考 README 中的「運作方式」章節。
Jaxpr 追蹤器#
在 Jax 中,一個特別重要的追蹤器是 Jaxpr 追蹤器,它將運算記錄到 Jaxpr (Jax 運算式) 中。Jaxpr 是一種資料結構,可以像迷你函數式程式語言一樣進行評估,因此 Jaxpr 是函式轉換的實用中繼表示法。
若要初步瞭解 Jaxpr,請考慮 make_jaxpr
轉換。make_jaxpr
本質上是一種「美化列印」轉換:它將函式轉換為一個函式,該函式在給定範例引數的情況下,會產生其計算的 Jaxpr 表示法。make_jaxpr
對於除錯和內省很有用。讓我們使用它來查看一些範例 Jaxpr 的結構。
def examine_jaxpr(closed_jaxpr):
jaxpr = closed_jaxpr.jaxpr
print("invars:", jaxpr.invars)
print("outvars:", jaxpr.outvars)
print("constvars:", jaxpr.constvars)
for eqn in jaxpr.eqns:
print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
print()
print("jaxpr:", jaxpr)
def foo(x):
return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))
print()
def bar(w, b, x):
return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))
foo
=====
invars: [Var(id=140424678569344):int32[]]
outvars: [Var(id=140424678569408):int32[]]
constvars: []
equation: [Var(id=140424678569344):int32[], 1] add [Var(id=140424678569408):int32[]] {}
jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }
bar
=====
invars: [Var(id=140424679042944):float32[5,10], Var(id=140424679043008):float32[5], Var(id=140424679043072):float32[10]]
outvars: [Var(id=140424679043712):float32[5], Var(id=140424679043072):float32[10]]
constvars: []
equation: [Var(id=140424679042944):float32[5,10], Var(id=140424679043072):float32[10]] dot_general [Var(id=140424679043520):float32[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32'), 'out_sharding': None}
equation: [Var(id=140424679043520):float32[5], Var(id=140424679043008):float32[5]] add [Var(id=140424679043584):float32[5]] {}
equation: [1.0] broadcast_in_dim [Var(id=140424679043648):float32[5]] {'shape': (5,), 'broadcast_dimensions': (), 'sharding': None}
equation: [Var(id=140424679043584):float32[5], Var(id=140424679043648):float32[5]] add [Var(id=140424679043712):float32[5]] {}
jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let
d:f32[5] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] a c
e:f32[5] = add d b
f:f32[5] = broadcast_in_dim[
broadcast_dimensions=()
shape=(5,)
sharding=None
] 1.0
g:f32[5] = add e f
in (g, c) }
jaxpr.invars
- Jaxpr 的invars
是 Jaxpr 的輸入變數列表,類似於 Python 函式中的引數。jaxpr.outvars
- Jaxpr 的outvars
是 Jaxpr 傳回的變數。每個 Jaxpr 都有多個輸出。jaxpr.constvars
-constvars
是變數列表,這些變數也是 Jaxpr 的輸入,但對應於追蹤中的常數 (稍後我們將更詳細地介紹這些常數)。jaxpr.eqns
- 方程式列表,本質上是 let 綁定。每個方程式都是輸入變數列表、輸出變數列表和基本運算,基本運算用於評估輸入以產生輸出。每個方程式也有一個params
,即參數字典。
總之,Jaxpr 封裝了一個簡單的程式,可以使用輸入進行評估以產生輸出。稍後我們將介紹如何精確地執行此操作。現在需要注意的重要事項是,Jaxpr 是一種資料結構,可以按照我們想要的任何方式進行操作和評估。
為什麼 Jaxpr 有用?#
Jaxpr 是簡單的程式表示法,易於轉換。而且由於 Jax 讓我們可以從 Python 函式中暫存 Jaxpr,因此它為我們提供了一種轉換以 Python 撰寫的數值程式的方法。
您的第一個解譯器:invert
#
讓我們嘗試實作一個簡單的「反轉器」函式,該函式接收原始函式的輸出,並傳回產生這些輸出的輸入。現在,讓我們專注於由其他可反轉的單元函式組成的簡單單元函式。
目標
def f(x):
return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
我們將實作此函式的方式是 (1) 將 f
追蹤到 Jaxpr 中,然後 (2)反向解譯 Jaxpr。在反向解譯 Jaxpr 時,對於每個方程式,我們都會在表格中查閱基本運算的反函數並套用它。
1. 追蹤函式#
讓我們使用 make_jaxpr
將函式追蹤到 Jaxpr 中。
# Importing Jax functions useful for tracing/interpreting.
from functools import wraps
from jax import core
from jax import lax
from jax._src.util import safe_map
jax.make_jaxpr
傳回封閉的 Jaxpr,這是一種已與追蹤中的常數 (literals
) 捆綁在一起的 Jaxpr。
def f(x):
return jnp.exp(jnp.tanh(x))
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)
{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[]
2. 評估 Jaxpr#
在我們編寫自訂 Jaxpr 解譯器之前,讓我們先實作「預設」解譯器 eval_jaxpr
,它會按原樣評估 Jaxpr,計算與原始未轉換的 Python 函式相同的值。
若要執行此操作,我們首先建立一個環境來儲存每個變數的值,並在我們評估 Jaxpr 中的每個方程式時更新環境。
def eval_jaxpr(jaxpr, consts, *args):
# Mapping from variable -> value
env = {}
def read(var):
# Literals are values baked into the Jaxpr
if type(var) is core.Literal:
return var.val
return env[var]
def write(var, val):
env[var] = val
# Bind args and consts to environment
safe_map(write, jaxpr.invars, args)
safe_map(write, jaxpr.constvars, consts)
# Loop through equations and evaluate primitives using `bind`
for eqn in jaxpr.eqns:
# Read inputs to equation from environment
invals = safe_map(read, eqn.invars)
# `bind` is how a primitive is called
outvals = eqn.primitive.bind(*invals, **eqn.params)
# Primitives may return multiple outputs or not
if not eqn.primitive.multiple_results:
outvals = [outvals]
# Write the results of the primitive into the environment
safe_map(write, eqn.outvars, outvals)
# Read the final result of the Jaxpr from the environment
return safe_map(read, jaxpr.outvars)
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
/tmp/ipykernel_1231/3734673940.py:7: DeprecationWarning: jax.core.Literal is deprecated. Use jax.extend.core.Literal instead, and see https://jax.dev.org.tw/en/latest/jax.extend.html for details.
if type(var) is core.Literal:
[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]
請注意,即使原始函式沒有,eval_jaxpr
也始終會傳回平面列表。
此外,此解譯器不處理高階基本運算 (例如 jit
和 pmap
),我們將在本指南中不介紹這些內容。您可以參考 core.eval_jaxpr
(連結) 以查看此解譯器未涵蓋的邊緣案例。
自訂 inverse
Jaxpr 解譯器#
inverse
解譯器看起來與 eval_jaxpr
沒有太大差異。我們首先設定登錄檔,將基本運算對應到它們的反函數。然後,我們將編寫一個自訂解譯器,在登錄檔中查閱基本運算。
事實證明,此解譯器也將與反向模式自動微分中使用的「轉置」解譯器類似,在此處找到。
inverse_registry = {}
我們現在將為某些基本運算註冊反函數。依照慣例,Jax 中的基本運算以 _p
結尾,許多常用的基本運算都位於 lax
中。
inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh
inverse
將首先追蹤函式,然後自訂解譯 Jaxpr。讓我們設定一個簡單的骨架。
def inverse(fun):
@wraps(fun)
def wrapped(*args, **kwargs):
# Since we assume unary functions, we won't worry about flattening and
# unflattening arguments.
closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
return out[0]
return wrapped
現在我們只需要定義 inverse_jaxpr
,它將向後遍歷 Jaxpr,並在可能的情況下反轉基本運算。
def inverse_jaxpr(jaxpr, consts, *args):
env = {}
def read(var):
if type(var) is core.Literal:
return var.val
return env[var]
def write(var, val):
env[var] = val
# Args now correspond to Jaxpr outvars
safe_map(write, jaxpr.outvars, args)
safe_map(write, jaxpr.constvars, consts)
# Looping backward
for eqn in jaxpr.eqns[::-1]:
# outvars are now invars
invals = safe_map(read, eqn.outvars)
if eqn.primitive not in inverse_registry:
raise NotImplementedError(
f"{eqn.primitive} does not have registered inverse.")
# Assuming a unary function
outval = inverse_registry[eqn.primitive](*invals)
safe_map(write, eqn.invars, [outval])
return safe_map(read, jaxpr.invars)
就是這樣!
def f(x):
return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
重要的是,您可以追蹤 Jaxpr 解譯器。
jax.make_jaxpr(inverse(f))(f(1.))
{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }
只需這樣做,即可將新的轉換新增至系統,而且您可以免費獲得與所有其他轉換的組合!例如,我們可以將 jit
、vmap
和 grad
與 inverse
一起使用!
jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)
Array([-3.1440797, 15.584931 , 2.2551253, 1.3155028, 1. ], dtype=float32, weak_type=True)
讀者練習#
處理具有多個引數且輸入部分已知的情況的基本運算,例如
lax.add_p
、lax.mul_p
。處理
xla_call
和xla_pmap
基本運算,這些基本運算無法與已編寫的eval_jaxpr
和inverse_jaxpr
一起運作。