Autodidax:從零開始的 JAX 核心#
是否曾經想學習 JAX 的運作方式,但又覺得實作難以理解?您真幸運!透過閱讀本教學,您將學習 JAX 核心系統中的每一個重要概念。您甚至會了解我們奇怪的術語!
這是一個正在進行中的草稿。 仍然缺少一些重要的組成部分,將在第 5 部分和第 6 部分(以及更多?)中推出。這裡也有一些簡化,我們尚未應用於主要系統,但我們會的。
第 1 部分:作為直譯器的轉換:標準求值、jvp
和 vmap
#
我們想要轉換看起來像這樣的函數
def f(x):
y = sin(x) * 2.
z = - y + x
return z
將像 sin
這樣的函數和作為中綴運算子基礎的算術運算(mul
、add
和 neg
)視為基本運算,意思是處理的原子單位,而不是組合。
「轉換」的意思是「以不同方式解釋」。我們想要覆寫基本運算的應用,讓不同的值流經我們的程式,而不是標準的解釋,即將基本運算應用於數值輸入以產生數值輸出。例如,我們可能想要將每個基本運算的應用替換為應用其 JVP 規則,並讓原始值-切線對流經我們的程式。此外,我們希望能夠組合多個轉換,從而產生直譯器堆疊。
JAX 核心機制#
我們可以實作直譯器堆疊,甚至讓它們在我們執行要轉換的 Python 函數時即時全部卸載。首先,讓我們定義這些基本運算,以便我們可以攔截它們的應用
from typing import NamedTuple
class Primitive(NamedTuple):
name: str
add_p = Primitive('add')
mul_p = Primitive('mul')
neg_p = Primitive("neg")
sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")
def add(x, y): return bind1(add_p, x, y)
def mul(x, y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
def greater(x, y): return bind1(greater_p, x, y)
def less(x, y): return bind1(less_p, x, y)
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
if axis is None:
axis = tuple(range(np.ndim(x)))
if type(axis) is int:
axis = (axis,)
return bind1(reduce_sum_p, x, axis=axis)
def bind1(prim, *args, **params):
out, = bind(prim, *args, **params)
return out
我們稍後會設定陣列資料類型和中綴運算子方法。
Primitive
只是一個具有名稱的物件,我們將解釋規則附加到該物件(每個轉換一個)。bind
函數是我們的攔截點:它會根據引數如何封裝在追蹤器中以及哪些直譯器處於活動狀態,來判斷要應用哪個轉換規則。
使用者程式碼呼叫的函數,例如 add
和 sin
,只是對 bind
呼叫的包裝函式。這些包裝函式讓我們可以控制引數如何傳遞給 bind
,特別是我們遵循一個方便的內部慣例:當我們呼叫 bind
時,我們將表示陣列資料的值作為位置引數傳遞,並透過關鍵字傳遞元資料,例如 axis
引數給 sum_p
。這種呼叫慣例簡化了一些核心邏輯(因為例如,要定義的 Tracer
類別的實例只能出現在 bind
的位置引數中)。包裝函式也可以提供文件字串!
我們將活動直譯器表示為堆疊。堆疊只是一個簡單的 list
,每個元素都是一個容器,包含一個整數層級(對應於堆疊中元素的高度)、一個直譯器類型(我們將其稱為 trace_type
)和一個可選欄位,用於直譯器需要的任何全域資料。我們將每個元素稱為 MainTrace
,儘管「直譯器」可能更具描述性。
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any
class MainTrace(NamedTuple):
level: int
trace_type: type['Trace']
global_data: Any | None
trace_stack: list[MainTrace] = []
dynamic_trace: MainTrace | None = None # to be employed in Part 3
@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
level = len(trace_stack)
main = MainTrace(level, trace_type, global_data)
trace_stack.append(main)
try:
yield main
finally:
trace_stack.pop()
當我們即將應用轉換時,我們將使用 new_main
將另一個直譯器推入堆疊。然後,當我們在函數中應用基本運算時,我們可以將 bind
視為首先由堆疊頂部的追蹤(即具有最高層級的追蹤)直譯。如果第一個直譯器本身在其基本運算的解釋規則中繫結了其他基本運算,例如 sin_p
的 JVP 規則可能繫結 cos_p
和 mul_p
,那麼這些 bind
呼叫將由下一層級的直譯器處理。
直譯器堆疊的底部是什麼?在底部,我們知道所有轉換直譯器都已完成,我們只想執行標準求值。因此,在底部,我們將放置一個求值直譯器。
讓我們草擬直譯器的介面,該介面基於 Trace
和 Tracer
基底類別。Tracer
表示一個封裝的值,可能攜帶直譯器使用的一些額外上下文資料。Trace
處理將值封裝到 Tracer
中,並處理基本運算應用。
class Trace:
main: MainTrace
def __init__(self, main: MainTrace) -> None:
self.main = main
def pure(self, val): assert False # must override
def lift(self, val): assert False # must override
def process_primitive(self, primitive, tracers, params):
assert False # must override
前兩個方法是關於在 Tracer
中封裝值,Tracer
是流經我們轉換的 Python 程式的物件。最後一個方法是我們將用於解釋基本運算應用的回呼。
Trace
本身不包含任何資料,除了對其對應的 MainTrace
實例的引用。實際上,在應用轉換期間,可能會建立和丟棄 Trace
的多個實例,而每次應用轉換只會建立一個 MainTrace
實例。
至於 Tracer
本身,每個都攜帶一個抽象值(並將中綴運算子轉發給它),其餘的取決於轉換。(Tracer
和 AbstractValue
之間的關係是,每個轉換有一個 Tracer
,每個基本類型至少有一個 AbstractValue
,例如陣列。)
import numpy as np
class Tracer:
_trace: Trace
__array_priority__ = 1000
@property
def aval(self):
assert False # must override
def full_lower(self):
return self # default implementation
def __neg__(self): return self.aval._neg(self)
def __add__(self, other): return self.aval._add(self, other)
def __radd__(self, other): return self.aval._radd(self, other)
def __mul__(self, other): return self.aval._mul(self, other)
def __rmul__(self, other): return self.aval._rmul(self, other)
def __gt__(self, other): return self.aval._gt(self, other)
def __lt__(self, other): return self.aval._lt(self, other)
def __bool__(self): return self.aval._bool(self)
def __nonzero__(self): return self.aval._nonzero(self)
def __getattr__(self, name):
try:
return getattr(self.aval, name)
except AttributeError:
raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")
def swap(f): return lambda x, y: f(y, x)
class ShapedArray:
array_abstraction_level = 1
shape: tuple[int, ...]
dtype: np.dtype
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
@property
def ndim(self):
return len(self.shape)
_neg = staticmethod(neg)
_add = staticmethod(add)
_radd = staticmethod(swap(add))
_mul = staticmethod(mul)
_rmul = staticmethod(swap(mul))
_gt = staticmethod(greater)
_lt = staticmethod(less)
@staticmethod
def _bool(tracer):
raise Exception("ShapedArray can't be unambiguously converted to bool")
@staticmethod
def _nonzero(tracer):
raise Exception("ShapedArray can't be unambiguously converted to bool")
def str_short(self):
return f'{self.dtype.name}[{",".join(str(d) for d in self.shape)}]'
def __hash__(self):
return hash((self.shape, self.dtype))
def __eq__(self, other):
return (type(self) is type(other) and
self.shape == other.shape and self.dtype == other.dtype)
def __repr__(self):
return f"ShapedArray(shape={self.shape}, dtype={self.dtype})"
class ConcreteArray(ShapedArray):
array_abstraction_level = 2
val: np.ndarray
def __init__(self, val):
self.val = val
self.shape = val.shape
self.dtype = val.dtype
@staticmethod
def _bool(tracer):
return bool(tracer.aval.val)
@staticmethod
def _nonzero(tracer):
return bool(tracer.aval.val)
def get_aval(x):
if isinstance(x, Tracer):
return x.aval
elif type(x) in jax_types:
return ConcreteArray(np.asarray(x))
else:
raise TypeError(x)
jax_types = {bool, int, float,
np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}
請注意,我們實際上為陣列設定了兩個 AbstractValue
,表示不同的抽象層級。ShapedArray
表示具有給定形狀和 dtype 的所有可能陣列的集合。ConcreteArray
表示由單個陣列值組成的單例集合。
現在我們已經設定了直譯器堆疊、直譯器的 Trace/Tracer API 和抽象值,我們可以回來實作 bind
def bind(prim, *args, **params):
top_trace = find_top_trace(args)
tracers = [full_raise(top_trace, arg) for arg in args]
outs = top_trace.process_primitive(prim, tracers, params)
return [full_lower(out) for out in outs]
主要操作是我們呼叫 find_top_trace
以找出哪個直譯器應處理此基本運算應用。然後我們呼叫頂層追蹤的 process_primitive
,以便追蹤可以應用其解釋規則。呼叫 full_raise
只是為了確保輸入封裝在頂層追蹤的 Tracer
實例中,而呼叫 full_lower
是一個可選的最佳化,以便我們盡可能多地從 Tracer
中取消封裝值。
import operator as op
def find_top_trace(xs) -> Trace:
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
default=trace_stack[0], key=op.attrgetter('level'))
if dynamic_trace and dynamic_trace.level > top_main.level:
top_main = dynamic_trace
return top_main.trace_type(top_main)
用文字來說,忽略第 3 部分之前的 dynamic_trace
步驟,find_top_trace
會傳回與其輸入上的 Tracer
相關聯的最高層級直譯器,否則會傳回堆疊底部的直譯器(至少現在始終是一個求值追蹤)。這與上面的描述有所不同,在上面的描述中,我們始終從執行堆疊頂部的直譯器開始,然後逐步向下,應用堆疊中的每個直譯器。相反,我們僅在基本運算繫結的輸入引數封裝在對應於該直譯器的 Tracer
中時才應用直譯器。這種最佳化讓我們可以跳過不相關的轉換,但會預設轉換主要遵循資料依賴性(除了特殊的堆疊底部直譯器,它會解釋所有內容)。
另一種方法是讓堆疊中的每個直譯器都解釋每個運算。這值得探索!JAX 的設計在很大程度上圍繞資料依賴性,因為這對於自動微分來說非常自然,而 JAX 的根源在於自動微分。但它可能過度擬合。
def full_lower(val: Any):
if isinstance(val, Tracer):
return val.full_lower()
else:
return val
def full_raise(trace: Trace, val: Any) -> Tracer:
if not isinstance(val, Tracer):
assert type(val) in jax_types
return trace.pure(val)
level = trace.main.level
if val._trace.main is trace.main:
return val
elif val._trace.main.level < level:
return trace.lift(val)
elif val._trace.main.level > level:
raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
else: # val._trace.level == level
raise Exception(f"Different traces at same level: {val._trace}, {trace}.")
full_raise
中的邏輯用於將值封裝到特定 Trace
的 Tracer
中,並根據上下文呼叫 Trace
上的不同方法:Trace.pure
在非 Tracer
常數上呼叫,Trace.lift
在已經是較低層級直譯器的 Tracer
的值上呼叫。這兩種方法可以共用相同的實作,但透過在核心邏輯中區分它們,我們可以為 Trace
子類別提供更多資訊。
這就是 JAX 核心的全部內容!現在我們可以開始新增直譯器了。
求值直譯器#
我們將從最簡單的直譯器開始:將位於直譯器堆疊底部的求值直譯器。
class EvalTrace(Trace):
pure = lift = lambda self, x: x # no boxing in Tracers needed
def process_primitive(self, primitive, tracers, params):
return impl_rules[primitive](*tracers, **params)
trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack
# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance
impl_rules = {}
impl_rules[add_p] = lambda x, y: [np.add(x, y)]
impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]
impl_rules[neg_p] = lambda x: [np.negative(x)]
impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
def broadcast_impl(x, *, shape, axes):
for axis in sorted(axes):
x = np.expand_dims(x, axis)
return [np.broadcast_to(x, shape)]
impl_rules[broadcast_p] = broadcast_impl
使用此直譯器,我們可以求值使用者函數
def f(x):
y = sin(x) * 2.
z = - y + x
return z
print(f(3.0))
2.7177599838802657
哇!就像在一個大圈子裡繞行。但這種間接的目的在於,現在我們可以新增一些真正的轉換。
使用 jvp
的前向模式自動微分#
首先,是一些輔助函數
import builtins
def zeros_like(val):
aval = get_aval(val)
return np.zeros(aval.shape, aval.dtype)
def unzip2(pairs):
lst1, lst2 = [], []
for x1, x2 in pairs:
lst1.append(x1)
lst2.append(x2)
return lst1, lst2
def map(f, *xs):
return list(builtins.map(f, *xs))
def zip(*args):
fst, *rest = args = map(list, args)
n = len(fst)
for arg in rest:
assert len(arg) == n
return list(builtins.zip(*args))
前向模式自動微分的 Tracer
攜帶原始值-切線對。Trace
應用 JVP 規則。
class JVPTracer(Tracer):
def __init__(self, trace, primal, tangent):
self._trace = trace
self.primal = primal
self.tangent = tangent
@property
def aval(self):
return get_aval(self.primal)
class JVPTrace(Trace):
pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))
def process_primitive(self, primitive, tracers, params):
primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
jvp_rule = jvp_rules[primitive]
primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)
return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]
jvp_rules = {}
請注意,pure
和 lift
都將值封裝到 JVPTracer
中,並帶有最少的上下文,即零切線值。
讓我們為基本運算新增一些 JVP 規則
def add_jvp(primals, tangents):
(x, y), (x_dot, y_dot) = primals, tangents
return [x + y], [x_dot + y_dot]
jvp_rules[add_p] = add_jvp
def mul_jvp(primals, tangents):
(x, y), (x_dot, y_dot) = primals, tangents
return [x * y], [x_dot * y + x * y_dot]
jvp_rules[mul_p] = mul_jvp
def sin_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return [sin(x)], [cos(x) * x_dot]
jvp_rules[sin_p] = sin_jvp
def cos_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return [cos(x)], [-sin(x) * x_dot]
jvp_rules[cos_p] = cos_jvp
def neg_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return [neg(x)], [neg(x_dot)]
jvp_rules[neg_p] = neg_jvp
def reduce_sum_jvp(primals, tangents, *, axis):
(x,), (x_dot,) = primals, tangents
return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]
jvp_rules[reduce_sum_p] = reduce_sum_jvp
def greater_jvp(primals, tangents):
(x, y), _ = primals, tangents
out_primal = greater(x, y)
return [out_primal], [zeros_like(out_primal)]
jvp_rules[greater_p] = greater_jvp
def less_jvp(primals, tangents):
(x, y), _ = primals, tangents
out_primal = less(x, y)
return [out_primal], [zeros_like(out_primal)]
jvp_rules[less_p] = less_jvp
最後,我們新增一個轉換 API 來啟動追蹤
def jvp_v1(f, primals, tangents):
with new_main(JVPTrace) as main:
trace = JVPTrace(main)
tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
out = f(*tracers_in)
tracer_out = full_raise(trace, out)
primal_out, tangent_out = tracer_out.primal, tracer_out.tangent
return primal_out, tangent_out
有了這個,我們就可以微分了!
x = 3.0
y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
print(sin_deriv_at_3)
print(cos(3.0))
-0.9899924966004454
-0.9899924966004454
def f(x):
y = sin(x) * 2.
z = - y + x
return z
x, xdot = 3., 1.
y, ydot = jvp_v1(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
def deriv(f):
return lambda x: jvp_v1(f, (x,), (1.,))[1]
print(deriv(sin)(3.))
print(deriv(deriv(sin))(3.))
print(deriv(deriv(deriv(sin)))(3.))
print(deriv(deriv(deriv(deriv(sin))))(3.))
-0.9899924966004454
-0.1411200080598672
0.9899924966004454
0.1411200080598672
def f(x):
if x > 0.: # Python control flow
return 2. * x
else:
return x
print(deriv(f)(3.))
print(deriv(f)(-3.))
2.0
1.0
Pytrees 和扁平化使用者函數的輸入和輸出#
jvp_v1
的一個限制是它假設使用者函數接受陣列作為位置引數,並產生單個陣列作為輸出。如果它產生一個列表作為輸出怎麼辦?或者接受巢狀容器作為輸入?在堆疊的每一層處理輸入和輸出中的所有可能容器將會很麻煩。相反,我們可以包裝使用者函數,以便包裝後的版本接受陣列作為輸入,並傳回扁平的陣列列表作為輸出。包裝函式只需要取消扁平化其輸入、呼叫使用者函數並扁平化輸出。
假設使用者始終為我們提供以陣列作為輸入並產生扁平陣列列表作為輸出的函數,以下是我們想要編寫 jvp
的方式
def jvp_flat(f, primals, tangents):
with new_main(JVPTrace) as main:
trace = JVPTrace(main)
tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
return primals_out, tangents_out
為了支援輸入和輸出中具有任意容器的使用者函數,以下是我們編寫面向使用者的 jvp
包裝函式的方式
def jvp(f, primals, tangents):
primals_flat, in_tree = tree_flatten(primals)
tangents_flat, in_tree2 = tree_flatten(tangents)
if in_tree != in_tree2: raise TypeError
f, out_tree = flatten_fun(f, in_tree)
primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
primals_out = tree_unflatten(out_tree(), primals_out_flat)
tangents_out = tree_unflatten(out_tree(), tangents_out_flat)
return primals_out, tangents_out
請注意,我們必須將使用者函數輸出的樹狀結構導回 flatten_fun
的呼叫者。在我們實際執行使用者函數之前,該資訊不可用,因此 flatten_fun
只傳回對可變單元格的引用,表示為 thunk。這些副作用是安全的,因為我們始終只執行一次使用者函數。(這種安全機制是 linear_util.py
中「linear」名稱的原因,在線性類型的意義上。)
剩下的就是編寫 tree_flatten
、tree_unflatten
和 flatten_fun
。
顯示程式碼單元來源
def flatten_fun(f, in_tree):
store = Store()
def flat_fun(*args_flat):
pytree_args = tree_unflatten(in_tree, args_flat)
out = f(*pytree_args)
out_flat, out_tree = tree_flatten(out)
store.set_value(out_tree)
return out_flat
return flat_fun, store
class Empty: pass
empty = Empty()
class Store:
val = empty
def set_value(self, val):
assert self.val is empty
self.val = val
def __call__(self):
return self.val
顯示程式碼單元來源
from collections.abc import Hashable, Iterable, Iterator
import itertools as it
from collections.abc import Callable
class NodeType(NamedTuple):
name: str
to_iterable: Callable
from_iterable: Callable
def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable
) -> None:
node_types[ty] = NodeType(str(ty), to_iter, from_iter)
node_types: dict[type, NodeType] = {}
register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))
register_pytree_node(list, lambda l: (None, l), lambda _, xs: list(xs))
register_pytree_node(dict,
lambda d: map(tuple, unzip2(sorted(d.items()))),
lambda keys, vals: dict(zip(keys, vals)))
class PyTreeDef(NamedTuple):
node_type: NodeType
node_metadata: Hashable
child_treedefs: tuple['PyTreeDef', ...]
class Leaf: pass
leaf = Leaf()
def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
children_iter, treedef = _tree_flatten(x)
return list(children_iter), treedef
def _tree_flatten(x: Any) -> tuple[Iterable, PyTreeDef]:
node_type = node_types.get(type(x))
if node_type:
node_metadata, children = node_type.to_iterable(x)
children_flat, child_trees = unzip2(map(_tree_flatten, children))
flattened = it.chain.from_iterable(children_flat)
return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))
else:
return [x], leaf
def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
return _tree_unflatten(treedef, iter(xs))
def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:
if treedef is leaf:
return next(xs)
else:
children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
return treedef.node_type.from_iterable(treedef.node_metadata, children)
透過這種處理 pytree 的 jvp
實作,我們現在可以處理任意輸入和輸出容器。這在未來的轉換中也會派上用場!
def f(x):
y = sin(x) * 2.
z = - y + x
return {'hi': z, 'there': [x, y]}
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
{'hi': np.float64(2.7177599838802657), 'there': [3.0, np.float64(0.2822400161197344)]}
{'hi': np.float64(2.979984993200891), 'there': [1.0, np.float64(-1.9799849932008908)]}
使用 vmap
的向量化批次處理#
首先,是幾個輔助函數,一個用於從未映射的值產生映射的抽象值(透過移除軸),另一個用於移動批次維度
def mapped_aval(batch_dim, aval):
shape = list(aval.shape)
del shape[batch_dim]
return ShapedArray(tuple(shape), aval.dtype)
def move_batch_axis(axis_size, src, dst, x):
if src is not_mapped:
target_shape = list(np.shape(x))
target_shape.insert(dst, axis_size)
return broadcast(x, target_shape, [dst])
elif src == dst:
return x
else:
return moveaxis(x, src, dst)
def moveaxis(x, src: int, dst: int):
perm = [i for i in range(np.ndim(x)) if i != src]
perm.insert(dst, src)
return transpose(x, perm)
向量化批次處理的 Tracer
攜帶批次處理的值和一個可選的整數,指示哪個軸(如果有的話)是批次軸。
from typing import Union
class NotMapped: pass
not_mapped = NotMapped()
BatchAxis = Union[NotMapped, int]
class BatchTracer(Tracer):
def __init__(self, trace, val, batch_dim: BatchAxis):
self._trace = trace
self.val = val
self.batch_dim = batch_dim
@property
def aval(self):
if self.batch_dim is not_mapped:
return get_aval(self.val)
else:
return mapped_aval(self.batch_dim, get_aval(self.val))
def full_lower(self):
if self.batch_dim is not_mapped:
return full_lower(self.val)
else:
return self
class BatchTrace(Trace):
pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)
def process_primitive(self, primitive, tracers, params):
vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)
vmap_rule = vmap_rules[primitive]
val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)
return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]
@property
def axis_size(self):
return self.main.global_data
vmap_rules = {}
在這裡,我們實作了可選的 Tracer.full_lower
方法,如果不需要批次處理追蹤器,因為它不表示批次處理的值,則該方法允許我們剝離批次處理追蹤器。
對於 BatchTrace
,類似於 JVPTrace
,方法 pure
和 lift
只是將值封裝在 BatchTracer
中,並帶有最少的上下文,在本例中,上下文是一個 batch_dim
,它採用哨兵值 not_mapped
。請注意,我們使用 MainTrace
的直譯器全域資料欄位來儲存批次軸大小。
接下來,我們可以為每個基本運算定義批次處理直譯器規則
from functools import partial
def binop_batching_rule(op, axis_size, vals_in, dims_in):
(x, y), (x_bdim, y_bdim) = vals_in, dims_in
if x_bdim != y_bdim:
if x_bdim is not_mapped:
x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
x_bdim = y_bdim
else:
y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)
def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
(x,), (x_bdim,) = vals_in, dims_in
return [op(x)], [x_bdim]
vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)
def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
(x,), (x_bdim,) = vals_in, dims_in
new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
return [reduce_sum(x, new_axis)], [out_bdim]
vmap_rules[reduce_sum_p] = reduce_sum_batching_rule
最後,我們新增一個轉換 API 來啟動追蹤
def vmap_flat(f, in_axes, *args):
axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)
if ax is not not_mapped}
with new_main(BatchTrace, axis_size) as main:
trace = BatchTrace(main)
tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x
for x, ax in zip(args, in_axes)]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)
outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)
for val_out, bdim in zip(vals_out, bdims_out)]
return outs_transposed
def vmap(f, in_axes):
def batched_f(*args):
args_flat, in_tree = tree_flatten(args)
in_axes_flat, in_tree2 = tree_flatten(in_axes)
if in_tree != in_tree2: raise TypeError
f_flat, out_tree = flatten_fun(f, in_tree)
outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
return tree_unflatten(out_tree(), outs_flat)
return batched_f
def add_one_to_a_scalar(scalar):
assert np.ndim(scalar) == 0
return 1 + scalar
vector_in = np.arange(3.)
vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)
print(vector_in)
print(vector_out)
[0. 1. 2.]
[1. 2. 3.]
def jacfwd(f, x):
pushfwd = lambda v: jvp(f, (x,), (v,))[1]
vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)
return vmap(pushfwd, (0,))(vecs_in)
def f(x):
return sin(x)
jacfwd(f, np.arange(3.))
array([[ 1. , 0. , -0. ],
[ 0. , 0.54030231, -0. ],
[ 0. , 0. , -0.41614684]])
這就是 jvp
和 vmap
的全部內容!
第 2 部分:Jaxprs#
接下來的轉換是即時編譯的 jit
和反向模式自動微分的 vjp
。(grad
只是 vjp
周圍的一個小型包裝函式。)雖然 jvp
和 vmap
只需要每個 Tracer
攜帶一點額外的上下文,但對於 jit
和 vjp
,我們需要更豐富的上下文:我們需要表示程式。也就是說,我們需要 jaxprs!
Jaxprs 是 JAX 程式的內部中介表示。它們是顯式類型、函數式、一階且採用 ANF 形式。我們需要 jit
的程式表示,因為 jit
的目的是將計算分階段移出 Python。對於我們想要分階段移出的任何計算,我們需要能夠將其表示為資料,並在我們追蹤 Python 函數時建構它。同樣,vjp
需要一種方式來表示反向模式自動微分的反向傳遞的計算。我們為這兩種需求使用相同的 jaxpr 程式表示。
(建構程式表示是最自由的追蹤轉換類型,因此,除了處理原生 Python 控制流程的問題外,任何轉換都可以透過首先追蹤到 jaxpr,然後解釋 jaxpr 來實作。)
Jaxpr 資料結構#
jaxpr 術語語法大致如下
jaxpr ::=
{ lambda <binder> , ... .
let <eqn>
...
in ( <atom> , ... ) }
binder ::= <var>:<array_type>
var ::= a | b | c | ...
atom ::= <var> | <literal>
literal ::= <int32> | <int64> | <float32> | <float64>
eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...
類型的語法如下
jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]
array_type ::= <dtype>[<shape>]
dtype ::= f32 | f64 | i32 | i64
shape ::= <int> , ...
我們如何將這些表示為 Python 資料結構?我們重複使用 ShapedArrays 來表示類型,並且可以使用一些 Python 結構來表示術語語法
class Var:
aval: ShapedArray
def __init__(self, aval): self.aval = aval
class Lit:
val: Any
aval: ShapedArray
def __init__(self, val):
self.aval = aval = raise_to_shaped(get_aval(val))
self.val = np.array(val, aval.dtype)
Atom = Union[Var, Lit]
class JaxprEqn(NamedTuple):
primitive: Primitive
inputs: list[Atom]
params: dict[str, Any]
out_binders: list[Var]
class Jaxpr(NamedTuple):
in_binders: list[Var]
eqns: list[JaxprEqn]
outs: list[Atom]
def __hash__(self): return id(self)
__eq__ = op.is_
def raise_to_shaped(aval):
return ShapedArray(aval.shape, aval.dtype)
類型檢查 jaxpr 涉及檢查是否沒有未綁定的變數、變數是否僅綁定一次,以及對於每個方程式,原始運算的類型是否與輸出綁定器的類型匹配。
class JaxprType(NamedTuple):
in_types: list[ShapedArray]
out_types: list[ShapedArray]
def __repr__(self):
in_types = ', '.join(aval.str_short() for aval in self.in_types)
out_types = ', '.join(aval.str_short() for aval in self.out_types)
return f'({in_types}) -> ({out_types})'
def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:
env: set[Var] = set()
for v in jaxpr.in_binders:
if v in env: raise TypeError
env.add(v)
for eqn in jaxpr.eqns:
in_types = [typecheck_atom(env, x) for x in eqn.inputs]
out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)
for out_binder, out_type in zip(eqn.out_binders, out_types):
if not out_type == out_binder.aval: raise TypeError
for out_binder in eqn.out_binders:
if out_binder in env: raise TypeError
env.add(out_binder)
in_types = [v.aval for v in jaxpr.in_binders]
out_types = [typecheck_atom(env, x) for x in jaxpr.outs]
return JaxprType(in_types, out_types)
def typecheck_atom(env: set[Var], x: Atom) -> ShapedArray:
if isinstance(x, Var):
if x not in env: raise TypeError("unbound variable")
return x.aval
elif isinstance(x, Lit):
return raise_to_shaped(get_aval(x.val))
else:
assert False
我們可以透過一個簡單的直譯器,將 jaxpr 表示的函數應用於引數。
def eval_jaxpr(jaxpr: Jaxpr, args: list[Any]) -> list[Any]:
env: dict[Var, Any] = {}
def read(x: Atom) -> Any:
return env[x] if type(x) is Var else x.val
def write(v: Var, val: Any) -> None:
assert v not in env # single-assignment
env[v] = val
map(write, jaxpr.in_binders, args)
for eqn in jaxpr.eqns:
in_vals = map(read, eqn.inputs)
outs = bind(eqn.primitive, *in_vals, **eqn.params)
map(write, eqn.out_binders, outs)
return map(read, jaxpr.outs)
def jaxpr_as_fun(jaxpr: Jaxpr):
return lambda *args: eval_jaxpr(jaxpr, args)
透過在直譯器中使用 bind
,這個直譯器本身是可追蹤的。
使用追蹤建構 jaxpr#
既然我們有了 jaxpr 作為資料結構,我們需要方法從追蹤 Python 程式碼來產生它們。一般來說,有兩種變體可以追蹤到 jaxpr;jit
使用其中一種,而 vjp
使用另一種。我們將從 jit
使用的那種開始,它也被控制流程原語使用,例如 lax.cond
、lax.while_loop
和 lax.scan
。
def split_list(lst: list[Any], n: int) -> tuple[list[Any], list[Any]]:
assert 0 <= n <= len(lst)
return lst[:n], lst[n:]
def partition_list(bs: list[bool], l: list[Any]) -> tuple[list[Any], list[Any]]:
assert len(bs) == len(l)
lists = lst1, lst2 = [], []
for b, x in zip(bs, l):
lists[b].append(x)
return lst1, lst2
# NB: the analogous class in JAX is called 'DynamicJaxprTracer'
class JaxprTracer(Tracer):
__slots__ = ['aval']
aval: ShapedArray
def __init__(self, trace, aval):
self._trace = trace
self.aval = aval
# NB: the analogous class in JAX is called 'DynamicJaxprTrace'
class JaxprTrace(Trace):
def new_arg(self, aval: ShapedArray) -> JaxprTracer:
aval = raise_to_shaped(aval)
tracer = self.builder.new_tracer(self, aval)
self.builder.tracer_to_var[id(tracer)] = Var(aval)
return tracer
def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:
tracer = self.builder.const_tracers.get(id(val))
if tracer is None:
tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))
self.builder.add_const(tracer, val)
return tracer
pure = lift = get_or_make_const_tracer
def process_primitive(self, primitive, tracers, params):
avals_in = [t.aval for t in tracers]
avals_out = abstract_eval_rules[primitive](*avals_in, **params)
out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]
inputs = [self.builder.getvar(t) for t in tracers]
outvars = [self.builder.add_var(t) for t in out_tracers]
self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))
return out_tracers
@property
def builder(self):
return self.main.global_data
# NB: in JAX, we instead attach abstract eval rules to Primitive instances
abstract_eval_rules = {}
請注意,我們將建構器物件作為直譯器全域資料保留,它會在我們建構 jaxpr 時追蹤變數、常數和方程式。
class JaxprBuilder:
eqns: list[JaxprEqn]
tracer_to_var: dict[int, Var]
const_tracers: dict[int, JaxprTracer]
constvals: dict[Var, Any]
tracers: list[JaxprTracer]
def __init__(self):
self.eqns = []
self.tracer_to_var = {}
self.const_tracers = {}
self.constvals = {}
self.tracers = []
def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer:
tracer = JaxprTracer(trace, aval)
self.tracers.append(tracer)
return tracer
def add_eqn(self, eqn: JaxprEqn) -> None:
self.eqns.append(eqn)
def add_var(self, tracer: JaxprTracer) -> Var:
assert id(tracer) not in self.tracer_to_var
var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)
return var
def getvar(self, tracer: JaxprTracer) -> Var:
var = self.tracer_to_var.get(id(tracer))
assert var is not None
return var
def add_const(self, tracer: JaxprTracer, val: Any) -> Var:
var = self.add_var(tracer)
self.const_tracers[id(val)] = tracer
self.constvals[var] = val
return var
def build(self, in_tracers: list[JaxprTracer], out_tracers: list[JaxprTracer]
) -> tuple[Jaxpr, list[Any]]:
constvars, constvals = unzip2(self.constvals.items())
t2v = lambda t: self.tracer_to_var[id(t)]
in_binders = constvars + [t2v(t) for t in in_tracers]
out_vars = [t2v(t) for t in out_tracers]
jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
typecheck_jaxpr(jaxpr)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
return jaxpr, constvals
def _inline_literals(jaxpr: Jaxpr, consts: list[Any]) -> tuple[Jaxpr, list[Any]]:
const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
new_const_binders, lit_binders = partition_list(scalars, const_binders)
new_consts, lit_vals = partition_list(scalars, consts)
literals = dict(zip(lit_binders, map(Lit, lit_vals)))
new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
new_outs = [literals.get(x, x) for x in jaxpr.outs]
new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
typecheck_jaxpr(new_jaxpr)
return new_jaxpr, new_consts
我們針對 JaxprTrace.process_primitive
需要的規則本質上是原始運算的類型規則:給定原始運算、其參數和輸入的類型,規則必須產生輸出的類型,然後將其與輸出 JaxprTracer
打包在一起。我們可以將抽象求值規則用於相同的目的,即使它們可能更通用(因為抽象求值規則必須接受 ConcreteArray 輸入,並且由於它們只需要傳回可能輸出的集合的上限,它們也可以產生 ConcreteArray 輸出)。我們將為其他產生 jaxpr 的追蹤機制重複使用這些抽象求值規則,其中潛在的額外通用性很有用。
def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
raise TypeError
if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
return [ShapedArray(x.shape, x.dtype)]
abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval
def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
raise TypeError
if x.shape != y.shape: raise TypeError
return [ShapedArray(x.shape, np.dtype('bool'))]
abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval
def vectorized_unop_abstract_eval(x: ShapedArray) -> list[ShapedArray]:
return [ShapedArray(x.shape, x.dtype)]
abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval
def reduce_sum_abstract_eval(x: ShapedArray, *, axis: tuple[int, ...]
) -> list[ShapedArray]:
axis_ = set(axis)
new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]
return [ShapedArray(tuple(new_shape), x.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval
def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
axes: Sequence[int]) -> list[ShapedArray]:
return [ShapedArray(tuple(shape), x.dtype)]
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
為了檢查我們 jaxpr 的實作,我們可以新增一個 make_jaxpr
轉換和一個美化列印器
from functools import lru_cache
@lru_cache # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in):
avals_in, in_tree = tree_flatten(avals_in)
f, out_tree = flatten_fun(f, in_tree)
builder = JaxprBuilder()
with new_main(JaxprTrace, builder) as main:
trace = JaxprTrace(main)
tracers_in = [trace.new_arg(aval) for aval in avals_in]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
jaxpr, consts = builder.build(tracers_in, tracers_out)
return jaxpr, consts, out_tree()
顯示程式碼單元來源
from collections import defaultdict
import string
class PPrint:
lines: list[tuple[int, str]]
def __init__(self, lines):
self.lines = lines
def indent(self, indent: int) -> 'PPrint':
return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])
def __add__(self, rhs: 'PPrint') -> 'PPrint':
return PPrint(self.lines + rhs.lines)
def __rshift__(self, rhs: 'PPrint') -> 'PPrint':
if not rhs.lines: return self
if not self.lines: return rhs
indent, s = self.lines[-1]
indented_block = rhs.indent(indent + len(s))
common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]
return PPrint(self.lines[:-1]
+ [(indent, common_line)]
+ indented_block.lines[1:])
def __str__(self) -> str:
return '\n'.join(' ' * indent + s for indent, s in self.lines)
def pp(s: Any) -> PPrint:
return PPrint([(0, line) for line in str(s).splitlines()])
def vcat(ps: list[PPrint]) -> PPrint:
return sum(ps, pp(''))
def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
namegen = (''.join(s) for r in it.count(1)
for s in it.permutations(string.ascii_lowercase, r))
names = defaultdict(lambda: next(namegen))
in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)
eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])
outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val)
for v in jaxpr.outs)
return (pp(f'{{ lambda {in_binders} .') +
((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))
def var_str(names: defaultdict[Var, str], v: Var) -> str:
return f'{names[v]}:{v.aval.str_short()}'
def pp_eqn(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
rule = pp_rules.get(eqn.primitive)
if rule:
return rule(names, eqn)
else:
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return lhs >> pp(' = ') >> rhs
def pp_params(params: dict[str, Any]) -> PPrint:
items = sorted(params.items())
if items:
return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')
else:
return pp(' ')
Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
pp_rules: dict[Primitive, Callable[..., PPrint]] = {}
jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))
print(jaxpr)
print(typecheck_jaxpr(jaxpr))
{ lambda a:float64[] .
let b:float64[] = mul 2.0 a
in ( b ) }
(float64[]) -> (float64[])
但這裡有一個限制:由於 find_top_trace
依據資料依賴性運作的方式,make_jaxpr_v1
無法將 Python 可呼叫物件執行的所有原始運算分段輸出。例如
jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))
print(jaxpr)
{ lambda .
let
in ( 4.0 ) }
這正是 omnistaging 修復的問題。我們希望確保由 make_jaxpr
啟動的 JaxprTrace
始終被應用,無論 bind
的任何輸入是否被封裝在對應的 JaxprTracer
實例中。我們可以透過使用第 1 部分中定義的 dynamic_trace
全域變數來實現這一點
@contextmanager
def new_dynamic(main: MainTrace):
global dynamic_trace
prev_dynamic_trace, dynamic_trace = dynamic_trace, main
try:
yield
finally:
dynamic_trace = prev_dynamic_trace
@lru_cache
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
) -> tuple[Jaxpr, list[Any], PyTreeDef]:
avals_in, in_tree = tree_flatten(avals_in)
f, out_tree = flatten_fun(f, in_tree)
builder = JaxprBuilder()
with new_main(JaxprTrace, builder) as main:
with new_dynamic(main):
trace = JaxprTrace(main)
tracers_in = [trace.new_arg(aval) for aval in avals_in]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
jaxpr, consts = builder.build(tracers_in, tracers_out)
return jaxpr, consts, out_tree()
jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))
print(jaxpr)
{ lambda .
let a:float64[] = mul 2.0 2.0
in ( a ) }
以這種方式使用 dynamic_trace
在概念上與暫存目前的直譯器堆疊,並在底部使用 JaxprTrace
啟動一個新的堆疊相同。也就是說,沒有比 dynamic_trace
更低的堆疊中的直譯器會被應用(因為 JaxprTrace.process_primitive
不會呼叫 bind
),但如果被追蹤到 jaxpr 的 Python 可呼叫物件本身使用轉換,那麼這些轉換可以被推送到 JaxprTrace
之上的直譯器堆疊中。但是暫時暫存直譯器堆疊會破壞系統狀態。dynamic_trace
標籤實現了相同的目標,同時保持系統狀態更簡單。
jaxpr 就到此為止!有了 jaxpr,我們可以實作剩餘的主要 JAX 功能。
第 3 部分:簡化的 jit
#
雖然 jit
具有類似轉換的 API,因為它接受 Python 可呼叫物件作為引數,但在底層它實際上是一個高階原語,而不是轉換。當原語由函數參數化時,它是高階的。
即時(“最終樣式”)和分段(“初始樣式”)處理#
有兩種選項可以處理高階原語。每種方法都需要不同的追蹤方法,並產生不同的權衡取捨
即時處理,其中
bind
接受 Python 可呼叫物件作為引數。 我們延遲形成 jaxpr,直到盡可能晚的時候,也就是直到我們在直譯器堆疊的底部執行最終直譯器時。這樣我們就可以在直譯器堆疊的底部交換JaxprTrace
,從而分段輸出而不是執行所有原始運算。透過這種方法,堆疊中的轉換會在我們像往常一樣執行 Python 可呼叫物件時被應用。這種方法實作起來可能非常棘手,但它盡可能通用,因為它允許高階原語不提高其引數的抽象層次,從而允許資料相關的 Python 控制流程。我們將這種方法稱為使用“最終樣式高階原語”,採用我們迄今為止使用的追蹤時卸載“最終樣式轉換”。分段處理,其中
bind
接受 jaxpr 作為引數。 在我們呼叫bind
之前,在原始包裝器中,我們可以僅使用make_jaxpr
預先形成一個 jaxpr,並完全完成 Python 可呼叫物件的工作。在這種情況下,make_jaxpr
將其JaxprTrace
放在直譯器堆疊的頂部,並且堆疊中較低的轉換(可能透過封閉的 Tracer 進入)不會應用於我們追蹤的 Python 可呼叫物件。(在 Python 可呼叫物件中應用的轉換會像往常一樣應用,並被新增到 JaxprTrace 之上的堆疊中。)相反,堆疊中較低的轉換稍後會應用於呼叫原語,並且呼叫原語的規則必須接著轉換 jaxpr 本身。因為我們預先追蹤到 jaxpr,所以這種方法無法支援資料相關的 Python 控制流程,但它更容易實作。我們將這種高階原語稱為“初始樣式高階原語”,並說它的 jaxpr 處理轉換規則是“初始樣式轉換規則”。
後者的方法適用於 jit
,因為我們不需要在使用者提供的 Python 可呼叫物件中支援資料相關的 Python 控制流程,因為 jit
的全部目的是將計算分段輸出 Python,以便由 XLA 執行。(相反,custom_jvp
是一個高階原語,我們希望在其中支援資料相關的 Python 控制流程。)
從歷史上看,我們在閱讀了 typed tagless final interpreters 論文後,開始使用“初始樣式”和“最終樣式”術語,並開玩笑地將 JAX 稱為“untyped tagful final interpreters”的實作。我們不聲稱要延續(或理解)這些術語背後的任何深刻含義;我們鬆散地使用“初始樣式”來表示“建構 AST 然後轉換它”,我們使用“最終樣式”來表示“在我們追蹤時轉換”。但這只是不精確但又黏著的術語。
使用初始樣式方法,以下是面向使用者的 jit
包裝器
def jit(f):
def f_jitted(*args):
avals_in = [raise_to_shaped(get_aval(x)) for x in args]
jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))
return tree_unflatten(out_tree, outs)
return f_jitted
xla_call_p = Primitive('xla_call')
對於任何新的原語,我們都需要為其提供轉換規則,從其求值規則開始。當我們評估 xla_call
原語的應用時,我們希望將計算分段輸出到 XLA。這涉及將 jaxpr 轉換為 XLA HLO 程式、將引數值傳輸到 XLA 裝置、執行 XLA 程式,然後將結果傳輸回來。我們將快取 XLA HLO 編譯,以便對於每個 jit
化的函數,它只需要針對每個引數形狀和 dtype 簽名執行一次。
首先,一些實用程式。
class IDHashable:
val: Any
def __init__(self, val):
self.val = val
def __hash__(self) -> int:
return id(self.val)
def __eq__(self, other):
return type(other) is IDHashable and id(self.val) == id(other.val)
接下來,我們將定義 xla_call
的求值規則
import io
from jax.extend.mlir import ir
from jax.extend.mlir.dialects import func
from jax.extend.mlir.dialects import stablehlo as hlo
from jax._src import xla_bridge as xb
class MlirContext(NamedTuple):
module: ir.Module
symbol_table: ir.SymbolTable
def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
consts, args = args[:num_consts], args[num_consts:]
hashable_consts = tuple(map(IDHashable, consts))
execute = xla_callable(IDHashable(jaxpr), hashable_consts)
return execute(*args)
impl_rules[xla_call_p] = xla_call_impl
@lru_cache
def xla_callable(hashable_jaxpr: IDHashable,
hashable_consts: tuple[IDHashable, ...]):
jaxpr: Jaxpr = hashable_jaxpr.val
typecheck_jaxpr(jaxpr)
consts = [x.val for x in hashable_consts]
in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
with ir.Context() as ctx, ir.Location.unknown(ctx):
hlo.register_dialect(ctx)
m = ir.Module.create()
c = MlirContext(m, ir.SymbolTable(m.operation))
with ir.InsertionPoint(c.module.body):
@func.func(*(aval_to_ir_type(aval) for aval in in_avals))
def main(*params):
return jaxpr_subcomp(c, jaxpr, _hlo_consts(consts) + params)
output = io.StringIO()
c.module.operation.print(file=output)
compiled = xb.get_backend(None).compile(output.getvalue())
return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])
def _mlir_dtype(dtype: np.dtype) -> ir.Type:
if np.issubdtype(dtype, np.signedinteger):
return ir.IntegerType.get_signless(np.iinfo(dtype).bits)
elif dtype == np.float32:
return ir.F32Type.get()
elif dtype == np.float64:
return ir.F64Type.get()
else:
raise NotImplementedError("MLIR conversion not implemented for ", dtype)
def aval_to_ir_type(aval: ShapedArray) -> ir.Type:
return ir.RankedTensorType.get(aval.shape, _mlir_dtype(aval.dtype))
def _hlo_const(x: Any) -> ir.Value:
a = np.asarray(x)
if a.dtype == np.bool_:
return hlo.constant(ir.DenseElementsAttr.get(
np.packbits(a, bitorder='little'), type=ir.IntegerType.get_signless(1),
shape=a.shape))
else:
return hlo.constant(ir.DenseElementsAttr.get(a))
def _hlo_consts(consts: list[Any]) -> list[ir.Value]:
unique_consts = {id(cnst): cnst for cnst in consts}
ir_consts = {id_: _hlo_const(cnst) for id_, cnst in unique_consts.items()}
return tuple(ir_consts[id(cnst)] for cnst in consts)
主要動作在 xla_callable
中,它使用 jaxpr_subcomp
將 jaxpr 編譯為 XLA HLO 程式,然後傳回一個可呼叫物件,該物件執行編譯後的程式
def jaxpr_subcomp(c: MlirContext, jaxpr: Jaxpr, args: list[ir.Value]) -> list[ir.Value]:
env: dict[Var, ir.Value] = {}
def read(x: Atom) -> ir.Value:
return env[x] if type(x) is Var else _hlo_const(np.asarray(x.val))
def write(v: Var, val: ir.Value) -> None:
env[v] = val
map(write, jaxpr.in_binders, args)
for eqn in jaxpr.eqns:
in_avals = [x.aval for x in eqn.inputs]
in_vals = map(read, eqn.inputs)
out_avals = [x.aval for x in eqn.out_binders]
rule = hlo_translations[eqn.primitive]
assert all(isinstance(v, ir.Value) for v in in_vals), in_vals
out_vals = rule(c, in_avals, out_avals, in_vals, **eqn.params)
assert all(isinstance(v, ir.Value) for v in out_vals), out_vals
map(write, eqn.out_binders, out_vals), out_vals
return map(read, jaxpr.outs)
def execute_compiled(compiled, out_avals, *args):
input_bufs = [input_handlers[type(x)](x) for x in args]
out_bufs = compiled.execute(input_bufs)
return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]
default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
[bool, int, float, np.ndarray, np.float64, np.float32]}
def handle_result(aval: ShapedArray, buf):
del aval # Unused for now
return np.asarray(buf)
hlo_translations = {}
請注意,jaxpr_subcomp
具有簡單直譯器的結構。這是一個常見的模式:我們處理 jaxpr 的方式通常是使用直譯器。與任何直譯器一樣,我們需要每個原語的解釋規則
def direct_translation(op, c, in_avals, out_avals, in_vals):
del c, in_avals, out_avals
return [op(*in_vals)]
hlo_translations[add_p] = partial(direct_translation, hlo.add)
hlo_translations[mul_p] = partial(direct_translation, hlo.multiply)
hlo_translations[neg_p] = partial(direct_translation, hlo.negate)
hlo_translations[sin_p] = partial(direct_translation, hlo.sine)
hlo_translations[cos_p] = partial(direct_translation, hlo.cosine)
def compare_translation(op, c, in_avals, out_avals, in_vals):
del c, out_avals
return [hlo.compare(*in_vals, hlo.ComparisonDirectionAttr.get(op))]
hlo_translations[greater_p] = partial(compare_translation, "GT")
hlo_translations[less_p] = partial(compare_translation, "LT")
def reduce_sum_translation(c, in_avals, out_avals, in_vals, *, axis):
del c
(x_aval,), (out_aval,), (x,) = in_avals, out_avals, in_vals
op = hlo.ReduceOp(
[aval_to_ir_type(out_aval)], [x], [_hlo_const(np.array(0, x_aval.dtype))],
axis)
scalar_type = aval_to_ir_type(ShapedArray((), x_aval.dtype))
reducer_region = op.body.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_region):
hlo.return_([hlo.add(*reducer_region.arguments)])
return op.results
hlo_translations[reduce_sum_p] = reduce_sum_translation
def broadcast_translation(c, in_avals, out_avals, in_vals, *, shape, axes):
del c
(x,), (out_aval,) = in_vals, out_avals
dims_complement = [i for i in range(len(shape)) if i not in axes]
return [hlo.broadcast_in_dim(aval_to_ir_type(out_aval), x, dims_complement)]
hlo_translations[broadcast_p] = broadcast_translation
有了這個,我們現在可以使用 jit
來分段輸出、編譯和執行使用 XLA 的程式!
@jit
def f(x, y):
print('tracing!')
return sin(x) * cos(y)
z = f(3., 4.) # 'tracing!' prints the first time
print(z)
tracing!
-0.09224219304455371
z = f(4., 5.) # 'tracing!' doesn't print, compilation cache hit!
print(z)
-0.21467624978306993
@jit
def f(x):
return reduce_sum(x, axis=0)
print(f(np.array([1., 2., 3.])))
6.0
def f(x):
y = sin(x) * 2.
z = - y + x
return z
def deriv(f):
return lambda x: jvp(f, (x,), (1.,))[1]
print( deriv(deriv(f))(3.))
print(jit(deriv(deriv(f)))(3.))
0.2822400161197344
0.2822400161197344
與其實作 jit
以首先追蹤到 jaxpr,然後將 jaxpr 降級為 XLA HLO,似乎我們可以跳過 jaxpr 步驟,並在追蹤時直接降級為 HLO。也就是說,也許我們可以改為使用 Trace
和 Tracer
來實作 jit
,它們在每個原始綁定上增量地附加到 XLA HLO 圖形。目前來說這是正確的,但當我們引入編譯後的 SPMD 計算時,這將是不可能的,因為在那裡我們必須在編譯程式之前知道所需的副本數量。
除了其求值規則之外,我們尚未定義 xla_call_p
的任何轉換規則。也就是說,我們還不能做 vmap
-of-jit
或 jvp
-of-jit
甚至 jit
-of-jit
。相反,jit
必須處於“頂層”。讓我們修正這個問題!
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
del num_consts # Unused
new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
num_consts=len(new_consts))
n = len(outs) // 2
primals_out, tangents_out = outs[:n], outs[n:]
return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule
@lru_cache
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
def jvp_traceable(*primals_and_tangents):
n = len(primals_and_tangents) // 2
primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
return jvp(jaxpr_as_fun(jaxpr), primals, tangents)
in_avals = [v.aval for v in jaxpr.in_binders]
new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)
return new_jaxpr, new_consts
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
del num_consts # Unused
new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
num_consts=len(new_consts))
return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule
@lru_cache
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
) -> tuple[Jaxpr, list[Any]]:
vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
in_avals = [unmapped_aval(axis_size, d, v.aval)
for v, d in zip(jaxpr.in_binders, bdims_in)]
new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)
return new_jaxpr, new_consts
def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
) -> ShapedArray:
if batch_dim is not_mapped:
return aval
else:
shape = list(aval.shape)
shape.insert(batch_dim, axis_size)
return ShapedArray(tuple(shape), aval.dtype)
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
del num_consts # Unused
jaxpr_type = typecheck_jaxpr(jaxpr)
if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
raise TypeError
return jaxpr_type.out_types
abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule
def xla_call_translation(c, in_avals, out_avals, in_vals, *, jaxpr, num_consts):
del num_consts, out_avals
# Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
with ir.InsertionPoint(c.module.body):
@func.func(*(aval_to_ir_type(aval) for aval in in_avals))
def inner_xla_call(*params):
return jaxpr_subcomp(c, jaxpr, params)
name = c.symbol_table.insert(inner_xla_call.func_op)
return func.CallOp(inner_xla_call.func_op, in_vals).results
hlo_translations[xla_call_p] = xla_call_translation
@jit
def f(x):
print('tracing!')
y = sin(x) * 2.
z = - y + x
return z
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
tracing!
2.7177599838802657
2.979984993200891
y, ydot = jvp(f, (x,), (xdot,)) # 'tracing!' not printed
ys = vmap(f, (0,))(np.arange(3.))
print(ys)
[ 0. -0.68294197 0.18140515]
缺少的一個部分是陣列的裝置記憶體持久性。也就是說,我們已經定義了 handle_result
將結果作為 NumPy 陣列傳輸回 CPU 記憶體,但通常最好避免傳輸結果,只是為了在下一個運算中將它們傳輸回去。我們可以透過引入 Array
類別來做到這一點,它可以包裝 XLA 緩衝區,並以其他方式進行 duck-type numpy.ndarray
。
def handle_result(aval: ShapedArray, buf): # noqa: F811
return Array(aval, buf)
class Array:
buf: Any
aval: ShapedArray
def __init__(self, aval, buf):
self.aval = aval
self.buf = buf
dtype = property(lambda self: self.aval.dtype)
shape = property(lambda self: self.aval.shape)
ndim = property(lambda self: self.aval.ndim)
def __array__(self): return np.asarray(self.buf)
def __repr__(self): return repr(np.asarray(self.buf))
def __str__(self): return str(np.asarray(self.buf))
_neg = staticmethod(neg)
_add = staticmethod(add)
_radd = staticmethod(add)
_mul = staticmethod(mul)
_rmul = staticmethod(mul)
_gt = staticmethod(greater)
_lt = staticmethod(less)
input_handlers[Array] = lambda x: x.buf
jax_types.add(Array)
@jit
def f(x):
y = sin(x) * 2.
z = - y + x
return z
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
顯示程式碼單元來源
def pprint_xla_call(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}
rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return vcat([lhs >> pp(' = ') >> rhs,
pp_jaxpr(eqn.params['jaxpr']).indent(2)])
pp_rules[xla_call_p] = pprint_xla_call
第 4 部分:linearize
和 vjp
(和 grad
!)#
linearize
和 vjp
自動微分函數建立在 jvp
之上,但也涉及 jaxpr。這是因為兩者都涉及分段輸出或延遲計算。
linearize
#
在 linearize
的情況下,我們想要分段輸出 jvp
計算的線性部分。也就是說,用 類似 Haskell 的類型簽名來說,如果我們有 jvp : (a -> b) -> (a, T a) -> (b, T b)
,那麼我們寫 linearize : (a -> b) -> a -> (b, T a -o T b)
,使用 T a
表示“a
的切線類型”,並使用“棒棒糖”-o
而不是箭頭 ->
來表示線性函數。我們也根據 jvp
定義 linearize
的語義
y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot)
對於 (y, y_dot)
給出與以下相同的結果
y, y_dot = jvp(f, (x,), (x_dot,))
其中 f_lin
的應用不會重做任何線性化工作。我們將延遲的線性部分 f_lin : T a -o T b
表示為 jaxpr。
順便一提,既然我們有了線性箭頭 -o
,我們可以為 jvp
提供稍微更具資訊性的類型
jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)
這裡我們寫 UnrestrictedUse
只是為了表明我們有一個特殊的配對,其中第一個元素可以以不受限制(非線性)的方式使用。結合線性箭頭,這種表示法只是為了表達函數 jvp f
以非線性的方式使用其第一個輸入,但以線性的方式使用其第二個輸入,產生對應的非線性輸出(可以以非線性的方式使用)並與線性輸出配對。這個更精細的類型簽名編碼了 jvp f
中的資料依賴性,這對於部分求值很有用。
為了從 JVP 建構 f_lin
jaxpr,我們需要執行部分求值:我們在追蹤時評估所有原始值,但將切線計算分段到 jaxpr 中。這是我們建構 jaxpr 的第二種方式。但是,make_jaxpr
及其底層的 JaxprTrace
/JaxprTracer
直譯器旨在分段輸出每個原始綁定,而第二種方法僅分段輸出那些對切線輸入具有資料依賴性的原始綁定。
首先,一些實用程式
def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
assert not len(lst) % 2
return split_list(lst, len(lst) // 2)
def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
l1, l2 = iter(l1), iter(l2)
out = [next(l2) if b else next(l1) for b in which]
assert next(l1, None) is next(l2, None) is None
return out
接下來,我們將透過將 jvp
與一般部分求值轉換結合起來編寫 linearize
,接下來將新增
def linearize_flat(f, *primals_in):
pvals_in = ([PartialVal.known(x) for x in primals_in] +
[PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
def f_jvp(*primals_tangents_in):
primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
return [*primals_out, *tangents_out]
jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
primal_pvals, _ = split_half(pvals_out)
assert all(pval.is_known for pval in primal_pvals)
primals_out = [pval.const for pval in primal_pvals]
f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
return primals_out, f_lin
def linearize(f, *primals_in):
primals_in_flat, in_tree = tree_flatten(primals_in)
f, out_tree = flatten_fun(f, in_tree)
primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
primals_out = tree_unflatten(out_tree(), primals_out_flat)
def f_lin(*tangents_in):
tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
if in_tree != in_tree2: raise TypeError
tangents_out_flat = f_lin_flat(*tangents_in_flat)
return tree_unflatten(out_tree(), tangents_out_flat)
return primals_out, f_lin
def vspace(aval: ShapedArray) -> ShapedArray:
return raise_to_shaped(aval) # TODO handle integers?
現在我們轉向一般部分求值轉換。目標是接受一個 Python 可呼叫物件和一個輸入列表,其中一些是已知的,一些是未知的,並產生 (1) 可以從已知輸入計算的所有輸出,以及 (2) 一個 jaxpr,表示 Python 可呼叫物件的計算部分,該部分只能在剩餘輸入已知後執行。
這種轉換很難在類型簽名中總結。如果我們假設輸入函數的類型簽名是 (a1, a2) -> (b1, b2)
,其中 a1
和 a2
分別表示已知和未知輸入,並且其中 b1
僅對 a1
具有資料依賴性,而 b2
對 a2
具有一些資料依賴性,那麼我們可以寫
partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)
換句話說,給定類型為 a1
的輸入值,partial_eval
產生類型為 b1
的輸出,以及存在量化類型 r
的“殘餘”值,表示在第二階段完成計算所需的中間值。它還產生類型為 (r, a2) -> b2
的函數,該函數接受殘餘值以及剩餘輸入,並產生剩餘輸出。
我們喜歡將部分求值視為將一個計算“解壓縮”為兩個。例如,考慮這個 jaxpr
{ lambda a:float64[] .
let b:float64[] = sin a
c:float64[] = neg b
in ( c ) }
JVP 的 jaxpr 看起來像
{ lambda a:float64[] b:float64[] .
let c:float64[] = sin a
d:float64[] = cos a
e:float64[] = mul d b
f:float64[] = neg c
g:float64[] = neg e
in ( f, g ) }
如果我們想像對這個 jaxpr 應用部分求值,其中第一個輸入已知,第二個輸入未知,我們最終會將 JVP jaxpr“解壓縮”為原始 jaxpr 和切線 jaxpr
{ lambda a:float64[] .
let c:float64[] = sin a
d:float64[] = cos a
f:float64[] = neg c
in ( f, d ) }
{ lambda d:float64[] b:float64[] .
let e:float64[] = mul d b
g:float64[] = neg e
in ( g ) }
第二個 jaxpr 表示我們從 linearize
想要的線性計算。
但是,與這個 jaxpr 範例不同,我們希望在評估輸入 Python 可呼叫物件時執行對已知值的計算。也就是說,與其為整個函數 (a1, a2) -> (b1, b2)
形成 jaxpr,首先將所有運算分段輸出 Python,然後再整理出現在可以評估的內容和必須延遲的內容,我們只想為那些必須延遲的運算形成 jaxpr,因為它們依賴於未知輸入。在自動微分的上下文中,這個功能最終使我們能夠處理諸如 grad(lambda x: x**2 if x > 0 else 0.)
之類的函數。Python 控制流程之所以有效,是因為部分求值將原始計算保留在 Python 中。因此,我們的 Trace
和 Tracer
子類必須即時整理出可以評估的內容和必須分段輸出到 jaxpr 中的內容。
首先,我們從 PartialVal
類別開始,它表示可以是已知或未知的值
class PartialVal(NamedTuple):
aval: ShapedArray
const: Any | None
@classmethod
def known(cls, val: Any):
return PartialVal(get_aval(val), val)
@classmethod
def unknown(cls, aval: ShapedArray):
return PartialVal(aval, None)
is_known = property(lambda self: self.const is not None)
is_unknown = property(lambda self: self.const is None)
部分求值將採用表示輸入的 PartialVal
列表,並傳回 PartialVal
輸出列表以及表示延遲計算的 jaxpr
def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
with new_main(PartialEvalTrace) as main:
trace = PartialEvalTrace(main)
tracers_in = [trace.new_arg(pval) for pval in pvals_in]
outs = f(*tracers_in)
tracers_out = [full_raise(trace, out) for out in outs]
pvals_out = [t.pval for t in tracers_out]
unk_tracers_in = [t for t in tracers_in if t.pval.is_unknown]
unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)
return jaxpr, pvals_out, consts
接下來,我們需要實作 PartialEvalTrace
及其 PartialEvalTracer
。這個直譯器將在追蹤資料依賴性的同時即時建構 jaxpr。為此,它在 PartialEvalTracer
節點(表示分段輸出的值)和 JaxprRecipe
節點(表示如何從其他值計算某些值的公式)之間建構一個二分有向無環圖 (DAG)。一種配方是 JaxprEqnRecipe
,對應於 JaxprEqn
的原始運算應用,但我們也有用於常數和 lambda 綁定器的配方類型
from weakref import ref, ReferenceType
class LambdaBindingRecipe(NamedTuple):
pass
class ConstRecipe(NamedTuple):
val: Any
class JaxprEqnRecipe(NamedTuple):
prim: Primitive
tracers_in: list['PartialEvalTracer']
params: dict[str, Any]
avals_out: list[ShapedArray]
tracer_refs_out: list['ReferenceType[PartialEvalTracer]']
JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
class PartialEvalTracer(Tracer):
pval: PartialVal
recipe: JaxprRecipe | None
def __init__(self, trace, pval, recipe):
self._trace = trace
self.pval = pval
self.recipe = recipe
aval = property(lambda self: self.pval.aval)
def full_lower(self):
if self.pval.is_known:
return full_lower(self.pval.const)
return self
PartialEvalTrace
包含建構 JaxprRecipe
和 PartialEvalTracer
圖形的邏輯。每個引數對應於一個 LambdaBindingRecipe
葉節點,每個常數都是一個 ConstRecipe
葉節點,其中包含對常數的引用。所有其他追蹤器和配方都來自 process_primitive
,它使用 JaxprEqnRecipe
形成追蹤器。
對於大多數原語,process_primitive
邏輯很簡單:如果所有輸入都是已知的,那麼我們可以在已知值上綁定原語(在 Python 中評估它)並避免形成對應於輸出的追蹤器。相反,如果任何輸入是未知的,那麼我們會分段輸出到表示原始運算應用的 JaxprEqnRecipe
中。為了建構表示未知輸出的追蹤器,我們需要 avals,我們從抽象評估規則中獲得它們。(請注意,追蹤器引用 JaxprEqnRecipe
,而 JaxprEqnRecipe
引用追蹤器;我們透過使用 weakref
避免循環垃圾回收。)
process_primitive
邏輯適用於大多數原語,但 xla_call_p
需要遞迴處理。因此,我們在 partial_eval_rules
字典中特殊處理其規則。
class PartialEvalTrace(Trace):
def new_arg(self, pval: PartialVal) -> Any:
return PartialEvalTracer(self, pval, LambdaBindingRecipe())
def lift(self, val: Any) -> PartialEvalTracer:
return PartialEvalTracer(self, PartialVal.known(val), None)
pure = lift
def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
if tracer.pval.is_unknown:
return tracer
else:
pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))
def process_primitive(self, primitive, tracers, params):
if all(t.pval.is_known for t in tracers):
return bind(primitive, *map(full_lower, tracers), **params)
rule = partial_eval_rules.get(primitive)
if rule: return rule(self, tracers, **params)
tracers_in = [self.instantiate_const(t) for t in tracers]
avals_in = [t.aval for t in tracers_in]
avals_out = abstract_eval_rules[primitive](*avals_in, **params)
tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
for aval in avals_out]
eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
map(ref, tracers_out))
for t in tracers_out: t.recipe = eqn
return tracers_out
partial_eval_rules = {}
既然我們可以使用 PartialEvalTrace
建構 jaxpr 的圖形表示,我們需要一種機制將圖形表示轉換為標準 jaxpr。jaxpr 對應於圖形的拓撲排序。
def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
tracers_out: list[PartialEvalTracer]):
tracer_to_var: dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
for t in tracers_in}
constvar_to_val: dict[int, Any] = {}
constid_to_var: dict[int, Var] = {}
processed_eqns: set[int] = set()
eqns: list[JaxprEqn] = []
for t in toposort(tracers_out, tracer_parents):
if isinstance(t.recipe, LambdaBindingRecipe):
assert id(t) in set(map(id, tracers_in))
elif isinstance(t.recipe, ConstRecipe):
val = t.recipe.val
var = constid_to_var.get(id(val))
if var is None:
aval = raise_to_shaped(get_aval(val))
var = constid_to_var[id(val)] = Var(aval)
constvar_to_val[var] = val
tracer_to_var[id(t)] = var
elif isinstance(t.recipe, JaxprEqnRecipe):
if id(t.recipe) not in processed_eqns:
eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
processed_eqns.add(id(t.recipe))
else:
raise TypeError(t.recipe)
constvars, constvals = unzip2(constvar_to_val.items())
in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
out_vars = [tracer_to_var[id(t)] for t in tracers_out]
jaxpr = Jaxpr(in_binders, eqns, out_vars)
typecheck_jaxpr(jaxpr)
return jaxpr, constvals
def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
) -> JaxprEqn:
inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
out_binders = [Var(aval) for aval in recipe.avals_out]
for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
if t_ref() is not None: tracer_to_var[id(t_ref())] = var
return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)
def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []
顯示程式碼單元來源
def toposort(out_nodes: list[Any], parents: Callable[[Any], list[Any]]):
if not out_nodes: return []
out_nodes = remove_duplicates(out_nodes)
child_counts = {}
stack = list(out_nodes)
while stack:
node = stack.pop()
if id(node) in child_counts:
child_counts[id(node)] += 1
else:
child_counts[id(node)] = 1
stack.extend(parents(node))
for node in out_nodes:
child_counts[id(node)] -= 1
sorted_nodes = []
childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
while childless_nodes:
node = childless_nodes.pop()
sorted_nodes.append(node)
for parent in parents(node):
if child_counts[id(parent)] == 1:
childless_nodes.append(parent)
else:
child_counts[id(parent)] -= 1
sorted_nodes = sorted_nodes[::-1]
check_toposort(sorted_nodes, parents)
return sorted_nodes
def remove_duplicates(lst):
seen = set()
return [x for x in lst if id(x) not in seen and not seen.add(id(x))]
def check_toposort(nodes: list[Any], parents: Callable[[Any], list[Any]]):
seen = set()
for node in nodes:
assert all(id(parent) in seen for parent in parents(node))
seen.add(id(node))
現在我們可以線性化了!
y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))
0.1411200080598672 0.1411200080598672
-0.9899924966004454 -0.9899924966004454
為了處理 linearize
-of-jit
,我們仍然需要為 xla_call_p
編寫部分求值規則。除了追蹤器簿記之外,主要任務是對 jaxpr 執行部分求值,將其“解壓縮”為兩個 jaxpr。
實際上需要編寫兩個規則:一個用於追蹤時的部分求值,我們將其稱為 xla_call_partial_eval
,另一個用於 jaxpr 的部分求值,我們將其稱為 xla_call_peval_eqn
。
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
del num_consts # Unused
in_unknowns = [not t.pval.is_known for t in tracers]
jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
known_vals = [t.pval.const for t in known_tracers]
outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)
res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
for v in jaxpr2.outs]
eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
dict(jaxpr=jaxpr2, num_consts=0),
[v.aval for v in jaxpr2.outs], map(ref, outs2))
for t in outs2: t.recipe = eqn
return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
instantiate: list[bool] | None = None,
) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
env: dict[Var, bool] = {}
residuals: set[Var] = set()
def read(x: Atom) -> bool:
return type(x) is Var and env[x]
def write(unk: bool, v: Var) -> None:
env[v] = unk
def new_res(x: Atom) -> Atom:
if type(x) is Var: residuals.add(x)
return x
eqns1, eqns2 = [], []
map(write, in_unknowns, jaxpr.in_binders)
for eqn in jaxpr.eqns:
unks_in = map(read, eqn.inputs)
rule = partial_eval_jaxpr_rules.get(eqn.primitive)
if rule:
eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
map(write, unks_out, eqn.out_binders)
elif any(unks_in):
inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
map(partial(write, True), eqn.out_binders)
else:
eqns1.append(eqn)
map(partial(write, False), eqn.out_binders)
out_unknowns = map(read, jaxpr.outs)
if instantiate is not None:
for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
if inst and not uk: new_res(v)
out_unknowns = map(op.or_, out_unknowns, instantiate)
residuals, num_res = list(residuals), len(residuals)
assert all(type(v) is Var for v in residuals), residuals
ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)
jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)
return jaxpr1, jaxpr2, out_unknowns, num_res
def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
jaxprty = typecheck_jaxpr(jaxpr) # (a1, a2) -> (b1, b2 )
jaxpr1ty = typecheck_jaxpr(jaxpr1) # a1 -> (b1, res)
jaxpr2ty = typecheck_jaxpr(jaxpr2) # (res, a2) -> b2
a1, a2 = partition_list(unks_in, jaxprty.in_types)
b1, b2 = partition_list(unks_out, jaxprty.out_types)
b1_, res = split_list(jaxpr1ty.out_types, len(b1))
res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
b2_ = jaxpr2ty.out_types
if jaxpr1ty.in_types != a1: raise TypeError
if jaxpr2ty.out_types != b2: raise TypeError
if b1 != b1_: raise TypeError
if res != res_: raise TypeError
if a2 != a2_: raise TypeError
if b2 != b2_: raise TypeError
partial_eval_jaxpr_rules = {}
def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
jaxpr = eqn.params['jaxpr']
jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
ins1, ins2 = partition_list(unks_in, eqn.inputs)
out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
out_binders1 + residuals)
eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn
有了這個,我們可以隨意組合 linearize
和 jit
@jit
def f(x):
y = sin(x) * 2.
z = - y + x
return z
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
2.7177599838802657 2.979984993200891
@jit
def f(x):
y = sin(x) * 2.
z = g(x, y)
return z
@jit
def g(x, y):
return cos(x) + y
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
-0.7077524804807109 -2.121105001260758
vjp
和 grad
#
vjp
轉換的工作方式與 linearize 非常相似。其類型簽名是類似的
linearize : (a -> b) -> a -> (b, T a -o T b)
vjp : (a -> b) -> a -> (b, T b -o T a)
唯一的區別是我們在傳回線性計算部分之前先對其進行轉置,使其從類型 T a -o T b
變為類型 T b -o T a
。也就是說,我們將 vjp
實作為,基本上,
def vjp(f, x):
y, f_lin = linearize(f, x)
f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
return y, f_vjp
由於我們將線性計算作為 jaxpr 而不是僅僅作為 Python 可呼叫物件,因此我們可以將轉置轉換實作為 jaxpr 直譯器。
def vjp_flat(f, *primals_in):
pvals_in = ([PartialVal.known(x) for x in primals_in] +
[PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
def f_jvp(*primals_tangents_in):
primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
return [*primals_out, *tangents_out]
jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in) # linearize
primal_pvals, _ = split_half(pvals_out)
assert all(pval.is_known for pval in primal_pvals)
primals_out = [pval.const for pval in primal_pvals]
transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]
f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)
return primals_out, f_vjp
def vjp(f, *primals_in):
primals_in_flat, in_tree = tree_flatten(primals_in)
f, out_tree = flatten_fun(f, in_tree)
primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)
primals_out = tree_unflatten(out_tree(), primals_out_flat)
def f_vjp(*cotangents_out):
cotangents_out_flat, _ = tree_flatten(cotangents_out)
cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
return tree_unflatten(in_tree, cotangents_in_flat)
return primals_out, f_vjp
class UndefPrimal(NamedTuple):
aval: ShapedArray
register_pytree_node(UndefPrimal,
lambda u: (u.aval, ()),
lambda aval, _: UndefPrimal(aval))
我們使用 UndefPrimal
實例來指示我們要轉置哪些引數。這些之所以出現是因為一般來說,為了明確關於封閉值,我們想要將類型為 a -> b -o c
的函數轉置為類型為 a -> c -o b
的函數。更一般來說,函數在其中線性的輸入可以分散在引數列表中。因此,我們使用 UndefPrimal
指示線性位置。我們將 UndefPrimal
註冊為 pytree 節點,因為 pytree 機制提供了一種方便的方法來從引數列表中修剪掉這些佔位符。
接下來,我們可以編寫 eval_jaxpr_transposed
,以及所有至少在一個引數中可以是線性的原語的轉置規則
# NB: the analogous function in JAX is called 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
) -> list[Any]:
primal_env: dict[Var, Any] = {}
ct_env: dict[Var, Any] = {}
def read_primal(x: Atom) -> Any:
return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val
def write_primal(v: Var, val: Any) -> None:
if type(val) is not UndefPrimal:
primal_env[v] = val
def read_cotangent(v: Var) -> Any:
return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))
def write_cotangent(x: Atom, val: Any):
if type(x) is Var and val is not None:
ct_env[x] = add(ct_env[x], val) if x in ct_env else val
map(write_primal, jaxpr.in_binders, args)
map(write_cotangent, jaxpr.outs, cotangents)
for eqn in jaxpr.eqns[::-1]:
primals_in = map(read_primal, eqn.inputs)
cts_in = map(read_cotangent, eqn.out_binders)
rule = transpose_rules[eqn.primitive]
cts_out = rule(cts_in, *primals_in, **eqn.params)
map(write_cotangent, eqn.inputs, cts_out)
return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)
if type(x) is UndefPrimal]
transpose_rules = {}
def mul_transpose_rule(cts, x, y):
z_bar, = cts
assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)
return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]
transpose_rules[mul_p] = mul_transpose_rule
def neg_transpose_rule(cts, x):
ybar, = cts
assert type(x) is UndefPrimal
return [neg(ybar)]
transpose_rules[neg_p] = neg_transpose_rule
def add_transpose_rule(cts, x, y):
z_bar, = cts
return [z_bar, z_bar]
transpose_rules[add_p] = add_transpose_rule
def reduce_sum_transpose_rule(cts, x, *, axis):
y_bar, = cts
return [broadcast(y_bar, x.aval.shape, axis)]
transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule
def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
del num_consts # Unused
undef_primals = [type(x) is UndefPrimal for x in invals]
transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
residuals, _ = partition_list(undef_primals, invals)
outs = bind(xla_call_p, *new_consts, *residuals, *cts,
jaxpr=transposed_jaxpr, num_consts=len(new_consts))
outs = iter(outs)
return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
) -> tuple[Jaxpr, list[Any]]:
avals_in, avals_out = typecheck_jaxpr(jaxpr)
traceable = partial(eval_jaxpr_transposed, jaxpr)
args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
typecheck_jaxpr(trans_jaxpr)
return trans_jaxpr, consts
既然我們可以線性化和轉置,我們終於可以編寫 grad
def grad(f):
def gradfun(x, *xs):
y, f_vjp = vjp(f, x, *xs)
if np.shape(y) != (): raise TypeError
x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
return x_bar
return gradfun
y, f_vjp = vjp(sin, 3.)
print(f_vjp(1.), cos(3.))
(np.float64(-0.9899924966004454),) -0.9899924966004454
def f(x):
y = sin(x) * 2.
z = - y + x
return z
print(grad(f)(3.))
2.979984993200891
@jit
def f(x):
y = x * 2.
z = g(y)
return z
@jit
def g(x):
return cos(x) * 2.
print(grad(f)(3.))
1.1176619927957034
這是一個組合性壓力測試
# from core_test.py fun_with_nested_calls_2
def foo(x):
@jit
def bar(y):
def baz(w):
q = jit(lambda x: y)(x)
q = q + jit(lambda: y)()
q = q + jit(lambda y: w + y)(y)
q = jit(lambda w: jit(sin)(x) * y)(1.0) + q
return q
p, t = jvp(baz, (x + 1.0,), (y,))
return t + (x * p)
return bar(x)
def assert_allclose(*vals):
for v1, v2 in zip(vals[:-1], vals[1:]):
np.testing.assert_allclose(v1, v2)
ans1 = f(3.)
ans2 = jit(f)(3.)
ans3, _ = jvp(f, (3.,), (5.,))
ans4, _ = jvp(jit(f), (3.,), (5.,))
assert_allclose(ans1, ans2, ans3, ans4)
deriv1 = grad(f)(3.)
deriv2 = grad(jit(f))(3.)
deriv3 = jit(grad(jit(f)))(3.)
_, deriv4 = jvp(f, (3.,), (1.,))
_, deriv5 = jvp(jit(f), (3.,), (1.,))
assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)
hess1 = grad(grad(f))(3.)
hess2 = grad(grad(jit(f)))(3.)
hess3 = grad(jit(grad(f)))(3.)
hess4 = jit(grad(grad(f)))(3.)
_, hess5 = jvp(grad(f), (3.,), (1.,))
_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
第 5 部分:控制流程原語 cond
#
接下來,我們將為分段輸出的控制流程新增高階原語。這些原語類似於第 3 部分中的 jit
(另一個高階原語),但不同之處在於它們由多個可呼叫物件而不是僅一個參數化。
新增 cond
#
我們引入一個 cond
原語來表示 jaxpr 內部一個函數或另一個函數的條件應用。我們將 cond
的類型寫為 Bool -> (a -> b) -> (a -> b) -> a -> b
。換句話說,cond
接受一個布林值(表示謂詞)和兩個類型相等的函數。根據謂詞的值,它將一個函數或另一個函數應用於其最終引數。
在 Python 中,我們將其表示為一個函數,該函數本身將兩個函數作為引數。與 jit
一樣,第一步是在其可呼叫引數上呼叫 make_jaxpr
,以將它們轉換為 jaxpr
def cond(pred, true_fn, false_fn, *operands):
avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
if out_tree != out_tree_: raise TypeError
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):
raise TypeError
outs = bind_cond(pred, *true_consts, *false_consts, *operands,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')
def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
) -> tuple[Jaxpr, Jaxpr]:
jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]
consts1, rest1 = split_list(jaxpr1.in_binders, n1)
consts2, rest2 = split_list(jaxpr2.in_binders, n2)
new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)
new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)
return new_jaxpr1, new_jaxpr2
def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)
return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
我們要求 true_jaxpr
和 false_jaxpr
具有相同的類型,但由於它們可能會封閉不同的常數(並且由於 jaxpr 只能表示封閉項,即不能有自由變數,而是進行了閉包轉換),我們需要使用輔助函數 _join_jaxpr_consts
來使兩個 jaxpr 的輸入綁定器列表保持一致。(為了更經濟,我們可以嘗試識別具有相同形狀的常數對,但我們只是連接常數列表。)
接下來,我們可以轉向為 cond
新增直譯器規則。其求值規則很簡單
def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
if pred:
return eval_jaxpr(true_jaxpr, operands)
else:
return eval_jaxpr(false_jaxpr, operands)
impl_rules[cond_p] = cond_impl
out = cond(True, lambda: 3, lambda: 4)
print(out)
3
對於其 JVP 和 vmap 規則,我們只需要呼叫我們為 jit
建立的相同 jvp_jaxpr
和 vmap_jaxpr
實用程式,然後再執行一次 _join_jaxpr_consts
def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
pred, *primals = primals
_ , *tangents = tangents
true_jaxpr , true_consts = jvp_jaxpr(true_jaxpr)
false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
primals_out, tangents_out = split_half(outs)
return primals_out, tangents_out
jvp_rules[cond_p] = cond_jvp_rule
out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)
2.0
def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
pred , *vals_in = vals_in
pred_dim, *dims_in = dims_in
if pred_dim is not not_mapped: raise NotImplementedError # TODO
true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))
false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
return outs, [0] * len(outs)
vmap_rules[cond_p] = cond_vmap_rule
xs = np.array([1., 2., 3])
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)
[2. 3. 4.]
請注意,我們目前不支援謂詞值本身是批次處理的情況。在主線 JAX 中,我們透過將條件轉換為 select 原語 來處理這種情況。只要 true_fun
和 false_fun
不涉及任何副作用原語,這種轉換在語義上是正確的。
這裡沒有表示的另一件事,但在主線 JAX 中存在,是將轉換應用於兩個類型相等的 jaxpr 可能會導致類型不同的 jaxpr。例如,將主線 JAX 版本的 vmap_jaxpr
應用於恆等函數 jaxpr
{ lambda a:float32[] .
let
in ( a ) }
若批次大小為 10,則會產生一個具有批次輸出的 jaxpr,其類型為 [float32[10]] -> [float32[10]]
;當應用於零函數 jaxpr 時
{ lambda a:float32[] .
let
in ( 0. ) }
會產生一個具有非批次輸出的 jaxpr,其類型為 [float32[10]] -> [float32[]]
。這是一種最佳化,旨在避免不必要地批次處理數值。但這意味著在 cond
中,我們需要額外一步來合併兩個轉換後的 jaxpr,以使其輸出類型一致。我們在這裡不需要這個步驟,因為我們選擇 vmap_jaxpr
始終在領先軸上批次處理所有輸出。
接下來,我們可以轉向抽象求值和 XLA 降低規則。
def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError
jaxpr_type = typecheck_jaxpr(true_jaxpr)
if jaxpr_type != typecheck_jaxpr(false_jaxpr):
raise TypeError
if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
raise TypeError
return jaxpr_type.out_types
abstract_eval_rules[cond_p] = cond_abstract_eval
def cond_translation(c, in_avals, out_avals, in_vals, *, true_jaxpr, false_jaxpr):
del in_avals # Unused
pred, *in_vals = in_vals
op = hlo.IfOp([aval_to_ir_type(aval) for aval in out_avals], pred)
with ir.InsertionPoint(op.true_branch.blocks.append()):
hlo.return_(jaxpr_subcomp(c, true_jaxpr, in_vals))
with ir.InsertionPoint(op.false_branch.blocks.append()):
hlo.return_(jaxpr_subcomp(c, false_jaxpr, in_vals))
return op.results
hlo_translations[cond_p] = cond_translation
out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)
2
最後,為了支援反向模式自動微分,我們需要部分求值和轉置規則。對於部分求值,我們需要引入另一個 jaxpr 處理工具 _join_jaxpr_res
,以處理將部分求值應用於 true_fun
和 false_fun
通常會產生不同殘差的事實。我們使用 _join_jaxpr_res
使轉換後的 jaxpr 的輸出類型一致(而 _join_jaxpr_consts
處理輸入類型)。
def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
pred_tracer, *tracers = tracers
assert pred_tracer.pval.is_known
pred = pred_tracer.pval.const
in_uks = [not t.pval.is_known for t in tracers]
*jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)
t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
known_tracers, unknown_tracers = partition_list(in_uks, tracers)
known_vals = [t.pval.const for t in known_tracers]
outs1_res = bind_cond(pred, *known_vals,
true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
outs1, res = split_list(outs1_res, len(outs1_res) - num_res)
pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
for v in t_jaxpr2.outs]
eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
[v.aval for v in t_jaxpr2.outs], map(ref, outs2))
for t in outs2: t.recipe = eqn
return merge_lists(out_uks, outs1, outs2)
partial_eval_rules[cond_p] = cond_partial_eval
def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: list[bool]
) -> tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, list[bool], int]:
_, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)
_, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
out_uks = map(op.or_, t_out_uks, f_out_uks)
t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)
f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)
t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)
num_res = t_nres + f_nres
return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res
def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
) -> tuple[Jaxpr, Jaxpr]:
jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)
out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)
assert out_types1 == out_types2
outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)
outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)
zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]
zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]
new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
return new_jaxpr1, new_jaxpr2
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)
3.14
def cond_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Atom]]:
pred_unk, *unks_in = unks_in
assert not pred_unk
true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
*jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)
t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])
outs1, outs2 = partition_list(unks_out, eqn.out_binders)
residuals, _ = split_list(t_jaxpr2.in_binders, num_res)
eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],
dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
outs1 + residuals)
eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
outs2)
res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals
return eqn1, eqn2, unks_out, res
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)
3.14
轉置是 transpose_jaxpr
的一個相當直接的應用。
def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
undef_primals = tuple(type(x) is UndefPrimal for x in invals)
true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
true_jaxpr, false_jaxpr = _join_jaxpr_consts(
true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
res = [x for x in invals if type(x) is not UndefPrimal]
outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,
true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
outs = iter(outs)
return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]
transpose_rules[cond_p] = cond_transpose_rule
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
2.0
顯示程式碼單元來源
def pprint_cond(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}
lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>
pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
for x in eqn.inputs)))
return vcat([lhs >> pp(' = ') >> rhs,
pp_jaxpr(true_jaxpr).indent(2),
pp_jaxpr(false_jaxpr).indent(2)])
pp_rules[cond_p] = pprint_cond