主要概念#
本節簡要介紹 JAX 套件的一些主要概念。
JAX 陣列 (jax.Array
)#
JAX 中的預設陣列實作是 jax.Array
。在許多方面,它與您可能熟悉的 NumPy 套件中的 numpy.ndarray
型別相似,但它有一些重要的差異。
陣列建立#
我們通常不直接呼叫 jax.Array
建構函式,而是透過 JAX API 函數建立陣列。例如,jax.numpy
提供了熟悉的 NumPy 風格陣列建構功能,例如 jax.numpy.zeros()
、jax.numpy.linspace()
、jax.numpy.arange()
等。
import jax
import jax.numpy as jnp
x = jnp.arange(5)
isinstance(x, jax.Array)
True
如果您在程式碼中使用 Python 型別註解,jax.Array
是 jax 陣列物件的適當註解(請參閱 jax.typing
以獲得更多討論)。
陣列裝置和分片#
JAX Array 物件有一個 devices
方法,可讓您檢查陣列內容的儲存位置。在最簡單的情況下,這將是單個 CPU 裝置
x.devices()
{CpuDevice(id=0)}
一般來說,陣列可能會分片到多個裝置上,其方式可以透過 sharding
屬性檢查
x.sharding
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
這裡的陣列位於單個裝置上,但一般來說,JAX 陣列可以分片到多個裝置,甚至多個主機。若要閱讀更多關於分片陣列和平行運算的資訊,請參閱平行程式設計簡介
轉換#
除了操作陣列的函數外,JAX 還包含許多對 JAX 函數進行操作的轉換。這些包括
以及其他幾個。轉換接受函數作為參數,並傳回新的轉換函數。例如,以下是如何 JIT 編譯簡單 SELU 函數的方法
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05
您通常會看到為了方便起見,使用 Python 的裝飾器語法套用轉換
@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
jit()
、vmap()
、grad()
等轉換是有效使用 JAX 的關鍵,我們將在後面的章節中詳細介紹它們。
追蹤#
轉換背後的魔力是追蹤器的概念。追蹤器是陣列物件的抽象佔位符,並傳遞給 JAX 函數,以便提取函數編碼的操作序列。
您可以透過在轉換後的 JAX 程式碼中列印任何陣列值來看到這一點;例如
@jax.jit
def f(x):
print(x)
return x + 1
x = jnp.arange(5)
result = f(x)
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace>
列印的值不是陣列 x
,而是一個 Tracer
實例,它代表 x
的基本屬性,例如其 shape
和 dtype
。透過使用追蹤值執行函數,JAX 可以在實際執行操作之前確定函數編碼的操作序列:jit()
、vmap()
和 grad()
等轉換然後可以將此輸入操作序列映射到轉換後的操作序列。
Jaxpr#
JAX 有自己的操作序列的中間表示法,稱為 jaxpr。jaxpr(JAX exPRession 的縮寫)是函數式程式的簡單表示法,包含一系列基本運算。
例如,考慮我們上面定義的 selu
函數
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
我們可以針對特定輸入使用 jax.make_jaxpr()
公用程式將此函數轉換為 jaxpr
x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
b:bool[5] = gt a 0.0
c:f32[5] = exp a
d:f32[5] = mul 1.6699999570846558 c
e:f32[5] = sub d 1.6699999570846558
f:f32[5] = pjit[
name=_where
jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
j:f32[5] = select_n g i h
in (j,) }
] b a e
k:f32[5] = mul 1.0499999523162842 f
in (k,) }
將其與 Python 函數定義進行比較,我們看到它編碼了函數表示的精確操作序列。我們將在JAX 內部機制:jaxpr 語言中更深入地探討 jaxpr。
Pytree#
JAX 函數和轉換基本上是對陣列進行操作,但在實務上,編寫適用於陣列集合的程式碼很方便:例如,神經網路可能會將其參數組織在具有有意義鍵的陣列字典中。JAX 不是逐個處理此類結構,而是依賴 pytree 抽象以一致的方式處理此類集合。
以下是一些可以作為 pytree 處理的物件範例
# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
[1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
# Named tuple of parameters
from typing import NamedTuple
class Params(NamedTuple):
a: int
b: float
params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]
JAX 有許多用於處理 Pytree 的通用公用程式;例如,函數 jax.tree.map()
可用於將函數映射到樹中的每個葉節點,而 jax.tree.reduce()
可用於在樹中的葉節點上應用縮減。
您可以在使用 pytree教學中了解更多資訊。
偽隨機數#
一般來說,JAX 努力與 NumPy 相容,但偽隨機數產生是一個值得注意的例外。NumPy 支援一種基於全域 state
的偽隨機數產生方法,可以使用 numpy.random.seed()
進行設定。全域隨機狀態與 JAX 的計算模型互動不佳,並且難以在不同的執行緒、進程和裝置之間強制執行可重現性。JAX 反而透過隨機 key
明確地追蹤狀態
from jax import random
key = random.key(43)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 43]
金鑰實際上是 NumPy 隱藏狀態物件的佔位符,但我們將其明確地傳遞給 jax.random()
函數。重要的是,隨機函數會消耗金鑰,但不會修改它:將相同的金鑰物件饋送到隨機函數始終會產生相同的樣本。
print(random.normal(key))
print(random.normal(key))
0.07520543
0.07520543
經驗法則是:永遠不要重複使用金鑰(除非您想要相同的輸出)。
為了產生不同且獨立的樣本,您必須在將金鑰傳遞給隨機函數之前明確地 split()
金鑰
for i in range(3):
new_key, subkey = random.split(key)
del key # The old key is consumed by split() -- we must never use it again.
val = random.normal(subkey)
del subkey # The subkey is consumed by normal().
print(f"draw {i}: {val}")
key = new_key # new_key is safe to use in the next iteration.
draw 0: -1.9133632183074951
draw 1: -1.4749839305877686
draw 2: -0.36703771352767944
請注意,此程式碼是執行緒安全的,因為本機隨機狀態消除了涉及全域狀態的可能競爭條件。jax.random.split()
是一個確定性函數,可將一個 key
轉換為多個獨立(在偽隨機性意義上)的金鑰。
有關 JAX 中偽隨機數的更多資訊,請參閱偽隨機數教學。