主要概念#

本節簡要介紹 JAX 套件的一些主要概念。

JAX 陣列 (jax.Array)#

JAX 中的預設陣列實作是 jax.Array。在許多方面,它與您可能熟悉的 NumPy 套件中的 numpy.ndarray 型別相似,但它有一些重要的差異。

陣列建立#

我們通常不直接呼叫 jax.Array 建構函式,而是透過 JAX API 函數建立陣列。例如,jax.numpy 提供了熟悉的 NumPy 風格陣列建構功能,例如 jax.numpy.zeros()jax.numpy.linspace()jax.numpy.arange() 等。

import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array)
True

如果您在程式碼中使用 Python 型別註解,jax.Array 是 jax 陣列物件的適當註解(請參閱 jax.typing 以獲得更多討論)。

陣列裝置和分片#

JAX Array 物件有一個 devices 方法,可讓您檢查陣列內容的儲存位置。在最簡單的情況下,這將是單個 CPU 裝置

x.devices()
{CpuDevice(id=0)}

一般來說,陣列可能會分片到多個裝置上,其方式可以透過 sharding 屬性檢查

x.sharding
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

這裡的陣列位於單個裝置上,但一般來說,JAX 陣列可以分片到多個裝置,甚至多個主機。若要閱讀更多關於分片陣列和平行運算的資訊,請參閱平行程式設計簡介

轉換#

除了操作陣列的函數外,JAX 還包含許多對 JAX 函數進行操作的轉換。這些包括

以及其他幾個。轉換接受函數作為參數,並傳回新的轉換函數。例如,以下是如何 JIT 編譯簡單 SELU 函數的方法

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05

您通常會看到為了方便起見,使用 Python 的裝飾器語法套用轉換

@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

jit()vmap()grad() 等轉換是有效使用 JAX 的關鍵,我們將在後面的章節中詳細介紹它們。

追蹤#

轉換背後的魔力是追蹤器的概念。追蹤器是陣列物件的抽象佔位符,並傳遞給 JAX 函數,以便提取函數編碼的操作序列。

您可以透過在轉換後的 JAX 程式碼中列印任何陣列值來看到這一點;例如

@jax.jit
def f(x):
  print(x)
  return x + 1

x = jnp.arange(5)
result = f(x)
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace>

列印的值不是陣列 x,而是一個 Tracer 實例,它代表 x 的基本屬性,例如其 shapedtype。透過使用追蹤值執行函數,JAX 可以在實際執行操作之前確定函數編碼的操作序列:jit()vmap()grad() 等轉換然後可以將此輸入操作序列映射到轉換後的操作序列。

Jaxpr#

JAX 有自己的操作序列的中間表示法,稱為 jaxpr。jaxpr(JAX exPRession 的縮寫)是函數式程式的簡單表示法,包含一系列基本運算

例如,考慮我們上面定義的 selu 函數

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

我們可以針對特定輸入使用 jax.make_jaxpr() 公用程式將此函數轉換為 jaxpr

x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558 c
    e:f32[5] = sub d 1.6699999570846558
    f:f32[5] = pjit[
      name=_where
      jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
          j:f32[5] = select_n g i h
        in (j,) }
    ] b a e
    k:f32[5] = mul 1.0499999523162842 f
  in (k,) }

將其與 Python 函數定義進行比較,我們看到它編碼了函數表示的精確操作序列。我們將在JAX 內部機制:jaxpr 語言中更深入地探討 jaxpr。

Pytree#

JAX 函數和轉換基本上是對陣列進行操作,但在實務上,編寫適用於陣列集合的程式碼很方便:例如,神經網路可能會將其參數組織在具有有意義鍵的陣列字典中。JAX 不是逐個處理此類結構,而是依賴 pytree 抽象以一致的方式處理此類集合。

以下是一些可以作為 pytree 處理的物件範例

# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  a: int
  b: float

params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]

JAX 有許多用於處理 Pytree 的通用公用程式;例如,函數 jax.tree.map() 可用於將函數映射到樹中的每個葉節點,而 jax.tree.reduce() 可用於在樹中的葉節點上應用縮減。

您可以在使用 pytree教學中了解更多資訊。

偽隨機數#

一般來說,JAX 努力與 NumPy 相容,但偽隨機數產生是一個值得注意的例外。NumPy 支援一種基於全域 state 的偽隨機數產生方法,可以使用 numpy.random.seed() 進行設定。全域隨機狀態與 JAX 的計算模型互動不佳,並且難以在不同的執行緒、進程和裝置之間強制執行可重現性。JAX 反而透過隨機 key 明確地追蹤狀態

from jax import random

key = random.key(43)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 43]

金鑰實際上是 NumPy 隱藏狀態物件的佔位符,但我們將其明確地傳遞給 jax.random() 函數。重要的是,隨機函數會消耗金鑰,但不會修改它:將相同的金鑰物件饋送到隨機函數始終會產生相同的樣本。

print(random.normal(key))
print(random.normal(key))
0.07520543
0.07520543

經驗法則是:永遠不要重複使用金鑰(除非您想要相同的輸出)。

為了產生不同且獨立的樣本,您必須在將金鑰傳遞給隨機函數之前明確地 split() 金鑰

for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.
draw 0: -1.9133632183074951
draw 1: -1.4749839305877686
draw 2: -0.36703771352767944

請注意,此程式碼是執行緒安全的,因為本機隨機狀態消除了涉及全域狀態的可能競爭條件。jax.random.split() 是一個確定性函數,可將一個 key 轉換為多個獨立(在偽隨機性意義上)的金鑰。

有關 JAX 中偽隨機數的更多資訊,請參閱偽隨機數教學。