使用 JIT 的控制流程與邏輯運算子#
當以 eager 方式執行時(在 jit
之外),JAX 程式碼與 Python 控制流程和邏輯運算子的運作方式與 Numpy 程式碼相同。將控制流程和邏輯運算子與 jit
搭配使用會更複雜。
簡而言之,Python 控制流程和邏輯運算子會在 JIT 編譯時進行評估,因此編譯後的函式代表通過控制流程圖的單一路徑(邏輯運算子透過短路影響路徑)。如果路徑取決於輸入的值,則(預設情況下)無法 JIT 編譯該函式。路徑可能取決於輸入的形狀或 dtype,並且每次在具有新形狀或 dtype 的輸入上呼叫函式時,都會重新編譯該函式。
from jax import grad, jit
import jax.numpy as jnp
例如,這可以運作
@jit
def f(x):
for i in range(3):
x = 2 * x
return x
print(f(3))
24
這個也可以
@jit
def g(x):
y = 0.
for i in range(x.shape[0]):
y = y + x[i]
return y
print(g(jnp.array([1., 2., 3.])))
6.0
但這個不行,至少預設情況下不行
@jit
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
# This will fail!
f(2)
---------------------------------------------------------------------------
TracerBoolConversionError Traceback (most recent call last)
Cell In[4], line 9
6 return -4 * x
8 # This will fail!
----> 9 f(2)
[... skipping hidden 13 frame]
Cell In[4], line 3, in f(x)
1 @jit
2 def f(x):
----> 3 if x < 3:
4 return 3. * x ** 2
5 else:
[... skipping hidden 1 frame]
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py:1498, in concretization_function_error.<locals>.error(self, arg)
1497 def error(self, arg):
-> 1498 raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_827/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
這個也不行
@jit
def g(x):
return (x > 0) and (x < 3)
# This will fail!
g(2)
---------------------------------------------------------------------------
TracerBoolConversionError Traceback (most recent call last)
Cell In[5], line 6
3 return (x > 0) and (x < 3)
5 # This will fail!
----> 6 g(2)
[... skipping hidden 13 frame]
Cell In[5], line 3, in g(x)
1 @jit
2 def g(x):
----> 3 return (x > 0) and (x < 3)
[... skipping hidden 1 frame]
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py:1498, in concretization_function_error.<locals>.error(self, arg)
1497 def error(self, arg):
-> 1498 raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_827/543860509.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
怎麼回事!?
當我們 jit
編譯函式時,我們通常希望編譯一個適用於許多不同引數值的函式版本,以便我們可以快取並重複使用編譯後的程式碼。這樣我們就不必在每次函式評估時都重新編譯。
例如,如果我們在陣列 jnp.array([1., 2., 3.], jnp.float32)
上評估 @jit
函式,我們可能希望編譯可以重複使用的程式碼,以在 jnp.array([4., 5., 6.], jnp.float32)
上評估該函式,以節省編譯時間。
為了獲得適用於許多不同引數值的 Python 程式碼視圖,JAX 使用 ShapedArray
抽象化作為輸入來追蹤它,其中每個抽象值代表具有固定形狀和 dtype 的所有陣列值的集合。例如,如果我們使用抽象值 ShapedArray((3,), jnp.float32)
進行追蹤,我們會獲得函式的視圖,該視圖可以重複用於相應陣列集合中的任何具體值。這表示我們可以節省編譯時間。
但這裡有一個權衡:如果我們在未提交到特定具體值的 ShapedArray((), jnp.float32)
上追蹤 Python 函式,當我們遇到類似 if x < 3
的行時,表達式 x < 3
會評估為抽象 ShapedArray((), jnp.bool_)
,它代表集合 {True, False}
。當 Python 嘗試將其強制轉換為具體的 True
或 False
時,我們會收到錯誤:我們不知道要採用哪個分支,並且無法繼續追蹤!權衡是,透過更高層次的抽象化,我們獲得了 Python 程式碼更通用的視圖(因此節省了重新編譯),但我們需要對 Python 程式碼施加更多限制才能完成追蹤。
好消息是您可以自己控制這種權衡。透過讓 jit
在更精細的抽象值上進行追蹤,您可以放寬可追蹤性限制。例如,使用 jit
的 static_argnames
(或 static_argnums
)引數,我們可以指定在某些引數的具體值上進行追蹤。以下是該範例函式再次呈現
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
f = jit(f, static_argnames='x')
print(f(2.))
12.0
這是另一個範例,這次涉及迴圈
def f(x, n):
y = 0.
for i in range(n):
y = y + x[i]
return y
f = jit(f, static_argnames='n')
f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)
實際上,迴圈會被靜態展開。JAX 也可以在更高層次的抽象化(例如 Unshaped
)上進行追蹤,但目前這並非任何轉換的預設值
️⚠️ 具有引數值相依形狀的函式
這些控制流程問題也會以更微妙的方式出現:我們想要 jit 的數值函式無法根據引數值來特化內部陣列的形狀(根據引數形狀特化是可以的)。作為一個簡單的範例,讓我們建立一個輸出恰好取決於輸入變數 length
的函式。
def example_fun(length, val):
return jnp.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
[4. 4. 4. 4. 4.]
bad_example_jit = jit(example_fun)
# this will fail:
bad_example_jit(10, 4)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[9], line 3
1 bad_example_jit = jit(example_fun)
2 # this will fail:
----> 3 bad_example_jit(10, 4)
[... skipping hidden 13 frame]
Cell In[8], line 2, in example_fun(length, val)
1 def example_fun(length, val):
----> 2 return jnp.ones((length,)) * val
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:6174, in ones(shape, dtype, device)
6172 raise TypeError("expected sequence object with len >= 0 or a single integer")
6173 if (m := _check_forgot_shape_tuple("ones", shape, dtype)): raise TypeError(m)
-> 6174 shape = canonicalize_shape(shape)
6175 dtypes.check_user_dtype_supported(dtype, "ones")
6176 return lax.full(shape, 1, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:102, in canonicalize_shape(shape, context)
100 return core.canonicalize_shape((shape,), context)
101 else:
--> 102 return core.canonicalize_shape(shape, context)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py:1643, in canonicalize_shape(shape, context)
1641 except TypeError:
1642 pass
-> 1643 raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /tmp/ipykernel_827/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnames tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnames='length')
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]
如果範例中的 length
很少變更,則 static_argnames
可能很方便,但如果它變更很多,那將是災難性的!
最後,如果您的函式具有全域副作用,則 JAX 的追蹤器可能會導致奇怪的事情發生。一個常見的陷阱是嘗試在 jit 函式內部列印陣列
@jit
def f(x):
print(x)
y = 2 * x
print(y)
return y
f(2)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>
Array(4, dtype=int32, weak_type=True)
結構化控制流程基本運算#
JAX 中有更多控制流程的選項。假設您想要避免重新編譯,但仍然想要使用可追蹤的控制流程,並避免展開大型迴圈。那麼您可以使用以下 4 個結構化控制流程基本運算
lax.cond
可微分lax.while_loop
正向模式可微分lax.fori_loop
一般而言為正向模式可微分;如果端點是靜態的,則為正向和反向模式可微分。lax.scan
可微分
cond
#
python 等效程式碼
def cond(pred, true_fun, false_fun, operand):
if pred:
return true_fun(operand)
else:
return false_fun(operand)
from jax import lax
operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)
Array([-1.], dtype=float32)
jax.lax
提供了另外兩個函式,允許在動態述詞上進行分支
lax.select
類似於lax.cond
的批次版本,選項表示為預先計算的陣列,而不是函式。lax.switch
類似於lax.cond
,但允許在任意數量的可呼叫選項之間進行切換。
此外,jax.numpy
為這些函式提供了幾個 numpy 樣式的介面
jnp.where
帶有三個引數是lax.select
的 numpy 樣式包裝函式。jnp.piecewise
是lax.switch
的 numpy 樣式包裝函式,但根據布林條件列表而不是單一純量索引進行切換。jnp.select
具有類似於jnp.piecewise
的 API,但選項以預先計算的陣列而不是函式形式給出。它是根據對lax.select
的多次呼叫來實作的。
while_loop
#
python 等效程式碼
def while_loop(cond_fun, body_fun, init_val):
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)
fori_loop
#
python 等效程式碼
def fori_loop(start, stop, body_fun, init_val):
val = init_val
for i in range(start, stop):
val = body_fun(i, val)
return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
Array(45, dtype=int32, weak_type=True)
摘要#
\(\ast\) = 與引數值無關的迴圈條件 - 展開迴圈
邏輯運算子#
jax.numpy
提供了 logical_and
、logical_or
和 logical_not
,它們在陣列上逐元素運算,並且可以在 jit
下進行評估而無需重新編譯。與它們的 Numpy 對應項一樣,二元運算子不會短路。位元運算子(&
、|
、~
)也可以與 jit
一起使用。
例如,考慮一個檢查其輸入是否為正偶數整數的函式。當輸入為純量時,純 Python 和 JAX 版本會給出相同的答案。
def python_check_positive_even(x):
is_even = x % 2 == 0
# `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
return is_even and (x > 0)
@jit
def jax_check_positive_even(x):
is_even = x % 2 == 0
# `logical_and` does not short circuit, so `x > 0` is always evaluated.
return jnp.logical_and(is_even, x > 0)
print(python_check_positive_even(24))
print(jax_check_positive_even(24))
True
True
當具有 logical_and
的 JAX 版本應用於陣列時,它會傳回逐元素的值。
x = jnp.array([-1, 2, 5])
print(jax_check_positive_even(x))
[False True False]
即使沒有 jit
,當 Python 邏輯運算子應用於多個元素的 JAX 陣列時,也會發生錯誤。這複製了 NumPy 的行為。
print(python_check_positive_even(x))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[17], line 1
----> 1 print(python_check_positive_even(x))
Cell In[15], line 4, in python_check_positive_even(x)
2 is_even = x % 2 == 0
3 # `and` short-circults, so when `is_even` is `False`, `x > 0` is not evaluated.
----> 4 return is_even and (x > 0)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/array.py:292, in ArrayImpl.__bool__(self)
291 def __bool__(self):
--> 292 core.check_bool_conversion(self)
293 return bool(self._value)
File ~/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py:655, in check_bool_conversion(arr)
652 raise ValueError("The truth value of an empty array is ambiguous. Use"
653 " `array.size > 0` to check that an array is not empty.")
654 if arr.size > 1:
--> 655 raise ValueError("The truth value of an array with more than one element"
656 " is ambiguous. Use a.any() or a.all()")
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Python 控制流程 + 自動微分#
請記住,上述關於控制流程和邏輯運算子的限制僅與 jit
相關。如果您只想將 grad
應用於您的 python 函式,而沒有 jit
,您可以像使用 Autograd(或 Pytorch 或 TF Eager)一樣,毫無問題地使用常規 Python 控制流程建構。
def f(x):
if x < 3:
return 3. * x ** 2
else:
return -4 * x
print(grad(f)(2.)) # ok!
print(grad(f)(4.)) # ok!
12.0
-4.0