JAX 思考方式#

Open in Colab Open in Kaggle

JAX 提供了一個簡單而強大的 API,用於編寫加速的數值程式碼,但有效地使用 JAX 有時需要額外的考量。本文檔旨在幫助建立對 JAX 如何運作的由下而上的理解,以便您可以更有效地使用它。

JAX 與 NumPy#

關鍵概念

  • JAX 為了方便起見,提供了受 NumPy 啟發的介面。

  • 透過鴨子型別,JAX 陣列通常可以用作 NumPy 陣列的直接替換。

  • 與 NumPy 陣列不同,JAX 陣列始終是不可變的。

NumPy 提供了一個眾所周知、功能強大的 API,用於處理數值資料。為了方便起見,JAX 提供了 jax.numpy,它緊密地反映了 numpy API,並提供了輕鬆進入 JAX 的途徑。幾乎所有可以使用 numpy 完成的事情都可以使用 jax.numpy 完成

import matplotlib.pyplot as plt
import numpy as np

x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np);
../_images/b2db475a8afa1d2e364a801f61f7b347b75a355e9da0be2f015a2d1aefdea45c.png
import jax.numpy as jnp

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
../_images/487cfe9c47318bd2e5849cf09dc8048af87a3364e9f0e0e524de8e950911888e.png

程式碼區塊除了將 np 替換為 jnp 之外,是相同的,結果也相同。正如我們所看到的,JAX 陣列通常可以直接代替 NumPy 陣列用於繪圖等操作。

陣列本身是以不同的 Python 型別實作的

type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib.xla_extension.ArrayImpl

Python 的 鴨子型別 允許 JAX 陣列和 NumPy 陣列在許多地方可以互換使用。

然而,JAX 陣列和 NumPy 陣列之間有一個重要的區別:JAX 陣列是不可變的,這意味著一旦建立,它們的內容就不能被更改。

以下是在 NumPy 中變更陣列的範例

# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10  1  2  3  4  5  6  7  8  9]

JAX 中的等效操作會導致錯誤,因為 JAX 陣列是不可變的

%xmode minimal
Exception reporting mode: Minimal
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.dev.org.tw/en/latest/_autosummary/jax.numpy.ndarray.at.html

對於更新個別元素,JAX 提供了一個 索引更新語法,它會傳回一個更新後的副本

y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]

NumPy、lax 和 XLA:JAX API 分層#

關鍵概念

  • jax.numpy 是一個高階包裝器,提供了熟悉的介面。

  • jax.lax 是一個較低階的 API,它更嚴格且通常更強大。

  • 所有 JAX 運算都是根據 XLA(加速線性代數編譯器)中的運算來實作的。

如果您查看 jax.numpy 的原始碼,您會看到所有運算最終都以 jax.lax 中定義的函式來表示。您可以將 jax.lax 視為一個更嚴格,但通常更強大的 API,用於處理多維陣列。

例如,雖然 jax.numpy 會隱式地提升引數以允許混合資料型別之間的運算,但 jax.lax 不會

import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0)  # jax.lax API requires explicit type promotion.
TypeError: lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).

如果直接使用 jax.lax,您必須在這種情況下顯式地進行型別提升

lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)

除了這種嚴格性之外,jax.lax 還為某些比 NumPy 支援的更通用的運算提供了高效的 API。

例如,考慮一個 1D 卷積,它可以用 NumPy 這樣表示

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

在底層,此 NumPy 運算被轉換為由 lax.conv_general_dilated 實作的更通用的卷積

from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

這是一個批次卷積運算,旨在高效地處理深度神經網路中常用的卷積型別。它需要更多的樣板程式碼,但比 NumPy 提供的卷積更靈活和可擴展(有關 JAX 卷積的更多詳細資訊,請參閱 JAX 中的卷積)。

在其核心,所有 jax.lax 運算都是 XLA 中運算的 Python 包裝器;例如,這裡的卷積實作是由 XLA:ConvWithGeneralPadding 提供的。每個 JAX 運算最終都以這些基本 XLA 運算來表示,這也是實現即時 (JIT) 編譯的原因。

是否使用 JIT#

關鍵概念

  • 預設情況下,JAX 會依序一次執行一個運算。

  • 使用即時 (JIT) 編譯裝飾器,可以一起最佳化運算序列並一次執行。

  • 並非所有 JAX 程式碼都可以進行 JIT 編譯,因為它要求陣列形狀是靜態的,並且在編譯時已知。

所有 JAX 運算都以 XLA 來表示的事實,使得 JAX 可以使用 XLA 編譯器非常有效率地執行程式碼區塊。

例如,考慮這個函式,它正規化 2D 矩陣的列,以 jax.numpy 運算來表示

import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

可以使用 jax.jit 轉換建立函式的即時編譯版本

from jax import jit
norm_compiled = jit(norm)

此函式傳回與原始函式相同的結果,達到標準浮點精度

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True

但是由於編譯(包括運算融合、避免分配臨時陣列以及許多其他技巧),在 JIT 編譯的情況下,執行時間可能會快幾個數量級(請注意使用 block_until_ready() 來考慮 JAX 的 非同步分派

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
755 μs ± 4.35 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
367 μs ± 2.73 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

也就是說,jax.jit 確實有局限性:特別是,它要求所有陣列都具有靜態形狀。這意味著某些 JAX 運算與 JIT 編譯不相容。

例如,此運算可以在逐運算模式下執行

def get_negatives(x):
  return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)

但是,如果您嘗試在 jit 模式下執行它,則會傳回錯誤

jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

這是因為該函式產生了一個陣列,其形狀在編譯時是未知的:輸出的尺寸取決於輸入陣列的值,因此它與 JIT 不相容。

JIT 機制:追蹤和靜態變數#

關鍵概念

  • JIT 和其他 JAX 轉換透過追蹤函式來確定其對特定形狀和型別輸入的影響。

  • 您不希望被追蹤的變數可以標記為靜態

為了有效使用 jax.jit,了解其運作方式很有用。讓我們在 JIT 編譯的函式中放入幾個 print() 語句,然後呼叫該函式

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace>
Array([0.25773212, 5.3623195 , 5.403243  ], dtype=float32)

請注意,print 語句會執行,但它印出的不是我們傳遞給函式的資料,而是代替它們的追蹤器物件。

這些追蹤器物件是 jax.jit 用於提取函式指定的運算序列的物件。基本追蹤器是編碼陣列形狀dtype 的替代品,但與值無關。然後,可以在 XLA 中有效地將此記錄的計算序列應用於具有相同形狀和 dtype 的新輸入,而無需重新執行 Python 程式碼。

當我們再次在匹配的輸入上呼叫編譯後的函式時,不需要重新編譯,並且不會印出任何內容,因為結果是在編譯後的 XLA 中計算,而不是在 Python 中計算

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)

提取的運算序列編碼在 JAX 表達式或簡稱 jaxpr 中。您可以使用 jax.make_jaxpr 轉換來檢視 jaxpr

from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

請注意這樣做的一個後果:由於 JIT 編譯是在沒有陣列內容資訊的情況下完成的,因此函式中的控制流程語句不能依賴於追蹤的值。例如,這會失敗

@jit
def f(x, neg):
  return -x if neg else x

f(1, True)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_2639/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

如果您有不想被追蹤的變數,則可以將它們標記為靜態以進行 JIT 編譯

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)
Array(-1, dtype=int32, weak_type=True)

請注意,使用不同的靜態引數呼叫 JIT 編譯的函式會導致重新編譯,因此該函式仍然可以按預期運作

f(1, False)
Array(1, dtype=int32, weak_type=True)

了解哪些值和運算將是靜態的,哪些將被追蹤,是有效使用 jax.jit 的關鍵部分。

靜態運算與追蹤運算#

關鍵概念

  • 就像值可以是靜態的或追蹤的一樣,運算也可以是靜態的或追蹤的。

  • 靜態運算在編譯時在 Python 中評估;追蹤運算在執行時在 XLA 中編譯和評估。

  • 對於您想要靜態的運算,請使用 numpy;對於您想要追蹤的運算,請使用 jax.numpy

靜態值和追蹤值之間的這種區別使得思考如何保持靜態值的靜態性變得重要。考慮這個函式

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_2639/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /tmp/ipykernel_2639/1983583872.py:6 (f)

這會失敗,並出現一個錯誤,指出找到了追蹤器而不是整數型別的具體值的一維序列。讓我們在函式中新增一些 print 語句,以了解為什麼會發生這種情況

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace>

請注意,雖然 x 被追蹤,但 x.shape 是一個靜態值。但是,當我們在這個靜態值上使用 jnp.arrayjnp.prod 時,它會變成一個追蹤的值,此時它不能在需要靜態輸入的函式(例如 reshape())中使用(回想一下:陣列形狀必須是靜態的)。

一個有用的模式是對應該是靜態的運算(即在編譯時完成)使用 numpy,而對應該被追蹤的運算(即在執行時編譯和執行)使用 jax.numpy。對於這個函式,它可能看起來像這樣

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)

因此,JAX 程式中的一個標準慣例是 import numpy as npimport jax.numpy as jnp,以便兩個介面都可用於更精細地控制運算是以靜態方式(使用 numpy,在編譯時執行一次)還是以追蹤方式(使用 jax.numpy,在執行時最佳化)執行。