🔪 JAX - The Sharp Bits 🔪#
當您在義大利鄉間漫步時,人們會毫不猶豫地告訴您 JAX 具有 “una anima di pura programmazione funzionale”。
JAX 是一種用於表達和組合數值程式轉換的語言。JAX 也能夠為 CPU 或加速器 (GPU/TPU) 編譯數值程式。JAX 非常適合許多數值和科學程式,但前提是它們必須以某些約束條件撰寫,我們將在下方說明。
import numpy as np
from jax import jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp
🔪 純函數#
JAX 轉換和編譯旨在僅適用於功能純粹的 Python 函數:所有輸入資料都透過函數參數傳遞,所有結果都透過函數結果輸出。如果使用相同的輸入調用,純函數將始終返回相同的結果。
以下是一些非功能純粹函數的範例,JAX 對這些函數的行為與 Python 直譯器不同。請注意,這些行為並非 JAX 系統保證;使用 JAX 的正確方法是僅將其用於功能純粹的 Python 函數。
def impure_print_side_effect(x):
print("Executing function") # This is a side-effect
return x
# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))
# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
Executing function
First call: 4.0
Second call: 5.0
Executing function
Third call, different type: [5.]
g = 0.
def impure_uses_globals(x):
return x + g
# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10. # Update the global
# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))
# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call: 4.0
Second call: 5.0
Third call, different type: [14.]
g = 0.
def impure_saves_global(x):
global g
g = x
return x
# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g) # Saved global has an internal JAX value
First call: 4.0
Saved global: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>
即使 Python 函數實際上在內部使用有狀態物件,只要它不讀取或寫入外部狀態,它仍然可以是功能純粹的
def pure_uses_internal_state(x):
state = dict(even=0, odd=0)
for i in range(10):
state['even' if i % 2 == 0 else 'odd'] += x
return state['even'] + state['odd']
print(jit(pure_uses_internal_state)(5.))
50.0
不建議在任何您想要 jit
的 JAX 函數或任何控制流程基本運算中使用迭代器。原因是迭代器是一個 python 物件,它會引入狀態來檢索下一個元素。因此,它與 JAX 的函數式程式設計模型不相容。在下面的程式碼中,有一些嘗試將迭代器與 JAX 一起使用的不正確範例。它們大多數會返回錯誤,但有些會給出意想不到的結果。
import jax.numpy as jnp
from jax import make_jaxpr
# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0
# lax.scan
def func11(arr, extra):
ones = jnp.ones(arr.shape)
def body(carry, aelems):
ae1, ae2 = aelems
return (carry + ae1 * ae2 + extra, carry)
return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error
# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
45
0
🔪 原位更新#
在 Numpy 中,您習慣這樣做
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
updated array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
但是,如果我們嘗試原位更新 JAX 裝置陣列,我們會收到錯誤! (☉_☉)
%xmode Minimal
Exception reporting mode: Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.dev.org.tw/en/latest/_autosummary/jax.numpy.ndarray.at.html
允許變數原位變更會使程式分析和轉換變得困難。JAX 要求程式是純函數。
相反地,JAX 提供使用 .at
屬性於 JAX 陣列的功能性陣列更新。
️⚠️ 在 jit
編譯的程式碼和 lax.while_loop
或 lax.fori_loop
內部,切片的大小不能是引數值的函數,而只能是引數形狀的函數 – 切片起始索引沒有此限制。請參閱下方的控制流程章節,以取得有關此限制的更多資訊。
陣列更新:x.at[idx].set(y)
#
例如,上面的更新可以寫成
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
updated array:
[[0. 0. 0.]
[1. 1. 1.]
[0. 0. 0.]]
與 NumPy 版本不同,JAX 的陣列更新函數以異地 (out-of-place) 方式運作。也就是說,更新後的陣列會作為新陣列返回,而原始陣列不會被更新修改。
print("original array unchanged:\n", jax_array)
original array unchanged:
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
但是,在 jit 編譯的程式碼中,如果 x.at[idx].set(y)
的輸入值 x
未被重複使用,編譯器會最佳化陣列更新以原位發生。
使用其他運算的陣列更新#
索引陣列更新不限於僅覆寫值。例如,我們可以執行索引加法如下
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]]
有關索引陣列更新的更多詳細資訊,請參閱 .at
屬性的文件。
🔪 越界索引#
在 Numpy 中,您習慣於在超出陣列邊界索引時拋出錯誤,就像這樣
np.arange(10)[11]
IndexError: index 11 is out of bounds for axis 0 with size 10
但是,從加速器上運行的程式碼引發錯誤可能很困難或不可能。因此,JAX 必須為越界索引選擇一些非錯誤行為 (類似於無效浮點運算如何導致 NaN
)。當索引操作是陣列索引更新 (例如 index_add
或類似 scatter
的基本運算) 時,將會跳過越界索引的更新;當操作是陣列索引檢索 (例如 NumPy 索引或類似 gather
的基本運算) 時,由於必須返回某些內容,因此索引會被鉗制在陣列的邊界內。例如,陣列的最後一個值將從此索引操作返回
jnp.arange(10)[11]
Array(9, dtype=int32)
如果您想要更精細地控制越界索引的行為,您可以使用 ndarray.at
的可選參數;例如
jnp.arange(10.0).at[11].get()
Array(9., dtype=float32)
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Array(nan, dtype=float32)
請注意,由於索引檢索的這種行為,jnp.nanargmin
和 jnp.nanargmax
等函數對於由 NaN 組成的切片會返回 -1,而 Numpy 會拋出錯誤。
另請注意,由於上述兩種行為不是彼此的反向操作,因此反向模式自動微分 (將索引更新轉換為索引檢索,反之亦然) 將無法保留越界索引的語意。因此,將 JAX 中的越界索引視為 未定義行為 的一種情況可能是個好主意。
🔪 非陣列輸入:NumPy vs. JAX#
NumPy 通常很樂意接受 Python 列表或元組作為其 API 函數的輸入
np.sum([1, 2, 3])
np.int64(6)
JAX 偏離了這一點,通常會返回有用的錯誤
jnp.sum([1, 2, 3])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
這是一個經過深思熟慮的設計選擇,因為將列表或元組傳遞給追蹤函數可能會導致靜默的效能下降,否則可能難以檢測到。
例如,考慮以下允許列表輸入的寬鬆版本 jnp.sum
def permissive_sum(x):
return jnp.sum(jnp.array(x))
x = list(range(10))
permissive_sum(x)
Array(45, dtype=int32)
輸出符合我們的預期,但這隱藏了底層潛在的效能問題。在 JAX 的追蹤和 JIT 編譯模型中,Python 列表或元組中的每個元素都被視為單獨的 JAX 變數,並單獨處理並推送到裝置。這可以在上面 permissive_sum
函數的 jaxpr 中看到
make_jaxpr(permissive_sum)(x)
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
j:i32[]. let
k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
u:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] k
v:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] l
w:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] m
x:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] n
y:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] o
z:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] p
ba:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] q
bb:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] r
bc:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] s
bd:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] t
be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd
bf:i32[] = reduce_sum[axes=(0,)] be
in (bf,) }
列表的每個條目都作為單獨的輸入處理,導致追蹤和編譯開銷隨著列表的大小線性增長。為了防止此類意外發生,JAX 避免了將列表和元組隱式轉換為陣列。
如果您想將元組或列表傳遞給 JAX 函數,您可以先將其顯式轉換為陣列
jnp.sum(jnp.array(x))
Array(45, dtype=int32)
🔪 隨機數#
JAX 的偽隨機數生成在重要方面與 Numpy 的不同。如需快速入門指南,請參閱偽隨機數。有關更多詳細資訊,請參閱偽隨機數教學。
🔪 控制流程#
🔪 動態形狀#
在 jax.jit
、jax.vmap
、jax.grad
等轉換中使用的 JAX 程式碼要求所有輸出陣列和中間陣列都具有靜態形狀:也就是說,形狀不能依賴於其他陣列中的值。
例如,如果您要實作自己的 jnp.nansum
版本,您可能會從類似這樣的程式碼開始
def nansum(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
x_without_nans = x[mask]
return x_without_nans.sum()
在 JIT 和其他轉換之外,這可以如預期般運作
x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))
10.0
如果您嘗試將 jax.jit
或其他轉換應用於此函數,它會出錯
jax.jit(nansum)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])
See https://jax.dev.org.tw/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
問題在於 x_without_nans
的大小取決於 x
中的值,這也表示其大小是動態的。通常在 JAX 中,可以透過其他方式解決對動態大小陣列的需求。例如,在這裡可以使用 jnp.where
的三引數形式將 NaN 值替換為零,從而在避免動態形狀的同時計算相同的結果
@jax.jit
def nansum_2(x):
mask = ~jnp.isnan(x) # boolean mask selecting non-nan values
return jnp.where(mask, x, 0).sum()
print(nansum_2(x))
10.0
在其他出現動態形狀陣列的情況下,也可以使用類似的技巧。
🔪 NaNs#
除錯 NaNs#
如果您想追蹤 NaN 在您的函數或梯度中發生的位置,您可以透過以下方式開啟 NaN 檢查器
設定
JAX_DEBUG_NANS=True
環境變數;在您的主檔案頂部附近加入
jax.config.update("jax_debug_nans", True)
;將
jax.config.parse_flags_with_absl()
加入您的主檔案,然後使用類似--jax_debug_nans=True
的命令列 flag 設定選項;
這將導致計算在產生 NaN 時立即出錯。開啟此選項會為 XLA 產生的每個浮點型別值新增 NaN 檢查。這表示值會被拉回主機並作為 ndarray 檢查,適用於不在 @jit
下的每個基本運算。對於 @jit
下的程式碼,會檢查每個 @jit
函數的輸出,如果存在 NaN,它將以解除最佳化的逐運算元模式重新運行該函數,從而有效地一次移除一個 @jit
層級。
可能會出現棘手的情況,例如僅在 @jit
下發生的 NaN,但在解除最佳化模式下不會產生。在這種情況下,您會看到列印出的警告訊息,但您的程式碼將繼續執行。
如果 NaN 是在梯度評估的反向傳遞中產生的,當堆疊追蹤中較上層的幾個 frame 引發例外時,您將處於 backward_pass 函數中,這本質上是一個簡單的 jaxpr 直譯器,它會反向遍歷基本運算的序列。在下面的範例中,我們使用命令列 env JAX_DEBUG_NANS=True ipython
啟動了 ipython repl,然後執行了以下程式碼
In [1]: import jax.numpy as jnp
In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
103 py_val = device_buffer.to_py()
104 if np.any(np.isnan(py_val)):
--> 105 raise FloatingPointError("invalid value")
106 else:
107 return Array(device_buffer, *result_shape)
FloatingPointError: invalid value
產生的 NaN 被捕獲。透過執行 %debug
,我們可以獲得事後除錯器。這也適用於 @jit
下的函數,如下面的範例所示。
In [4]: from jax import jit
In [5]: @jit
...: def f(x, y):
...: a = x * y
...: b = (x + y) / (x - y)
...: c = a + 2
...: return a + b * c
...:
In [6]: x = jnp.array([2., 0.])
In [7]: y = jnp.array([3., 0.])
In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)
... stack trace ...
<ipython-input-5-619b39acbaac> in f(x, y)
2 def f(x, y):
3 a = x * y
----> 4 b = (x + y) / (x - y)
5 c = a + 2
6 return a + b * c
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
當此程式碼在 @jit
函數的輸出中看到 NaN 時,它會調用解除最佳化的程式碼,因此我們仍然可以獲得清晰的堆疊追蹤。而且我們可以執行帶有 %debug
的事後除錯器來檢查所有值,以找出錯誤。
⚠️ 如果您沒有在除錯,則不應開啟 NaN 檢查器,因為它可能會引入大量裝置-主機往返和效能衰退!
⚠️ NaN 檢查器不適用於 pmap
。若要除錯 pmap
程式碼中的 NaN,一種嘗試方法是將 pmap
替換為 vmap
。
🔪 倍精準度 (64 位元)#
目前,JAX 預設強制執行單精準度數字,以減輕 Numpy API 傾向於積極將運算元提升為 double
的情況。對於許多機器學習應用程式來說,這是期望的行為,但它可能會讓您感到驚訝!
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype
/tmp/ipykernel_1169/1258726447.py:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')
若要使用倍精準度數字,您需要在啟動時設定 jax_enable_x64
配置變數。
有幾種方法可以做到這一點
您可以透過設定環境變數
JAX_ENABLE_X64=True
來啟用 64 位元模式。您可以手動設定啟動時的
jax_enable_x64
配置 flag# again, this only works on startup! import jax jax.config.update("jax_enable_x64", True)
您可以使用
absl.app.run(main)
來剖析命令列 flagsimport jax jax.config.config_with_absl()
如果您希望 JAX 為您執行 absl 剖析,也就是說,您不想執行
absl.app.run(main)
,您可以改用import jax if __name__ == '__main__': # calls jax.config.config_with_absl() *and* runs absl parsing jax.config.parse_flags_with_absl()
請注意,#2-#4 適用於 JAX 的任何配置選項。
然後我們可以確認 x64
模式已啟用,例如
import jax
import jax.numpy as jnp
from jax import random
jax.config.update("jax_enable_x64", True)
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')
注意事項#
⚠️ XLA 不支援所有後端的 64 位元卷積!
🔪 與 NumPy 的其他差異#
雖然 jax.numpy
盡一切努力複製 numpy 的 API 行為,但確實存在行為不同的邊角案例。上述章節中詳細討論了許多此類案例;在這裡,我們列出其他幾個已知 API 不同的地方。
對於二元運算,JAX 的型別提升規則與 NumPy 使用的規則略有不同。請參閱 型別提升語意 以取得更多詳細資訊。
當執行不安全的型別轉換 (即目標 dtype 無法表示輸入值的轉換) 時,JAX 的行為可能取決於後端,並且通常可能與 NumPy 的行為不同。Numpy 允許透過
casting
引數控制這些情況下的結果 (請參閱np.ndarray.astype
);JAX 不提供任何此類配置,而是直接繼承 XLA:ConvertElementType 的行為。以下是一個不安全轉換的範例,NumPy 和 JAX 之間的結果不同
>>> np.arange(254.0, 258.0).astype('uint8') array([254, 255, 0, 1], dtype=uint8) >>> jnp.arange(254.0, 258.0).astype('uint8') Array([254, 255, 255, 255], dtype=uint8)
當從浮點型別轉換為整數型別或反之亦然的極端值時,通常會出現這種不匹配的情況。
完。#
如果這裡沒有涵蓋任何讓您痛哭流涕的事情,請告訴我們,我們將擴充這些入門級的建議!