即時編譯#

在本節中,我們將進一步探討 JAX 的運作方式,以及如何使其高效能。我們將討論 jax.jit() 轉換,它將對 JAX Python 函式執行即時 (JIT) 編譯,使其可以在 XLA 中有效率地執行。

JAX 轉換如何運作#

在前一節中,我們討論了 JAX 允許我們轉換 Python 函式。JAX 透過將每個函式簡化為一系列的 基本運算 操作來完成此操作,每個操作代表一個基本的計算單元。

查看函式背後基本運算序列的一種方法是使用 jax.make_jaxpr()

import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }

文件的 JAX 內部機制:jaxpr 語言 章節提供了關於上述輸出含義的更多資訊。

重要的是,請注意 jaxpr 沒有捕捉到函式中存在的副作用:其中沒有任何內容對應於 global_list.append(x)。這是一個特性,而不是錯誤:JAX 轉換旨在理解無副作用(又稱函數式純粹)的程式碼。如果純函式副作用是不熟悉的術語,這在 🔪 JAX - 尖銳之處 🔪:純函式 中有更詳細的解釋。

不純函式是危險的,因為在 JAX 轉換下,它們很可能無法如預期般運作;它們可能會靜默失敗,或產生令人驚訝的下游錯誤,例如洩漏的 Tracers。此外,JAX 通常無法偵測到何時存在副作用。(如果您想要偵錯列印,請使用 jax.debug.print()。若要以效能為代價表達一般副作用,請參閱 jax.experimental.io_callback()。若要以效能為代價檢查 tracer 洩漏,請使用 jax.check_tracer_leaks())。

追蹤時,JAX 會將每個引數包裝在 tracer 物件中。然後,這些 tracer 會記錄在函式呼叫期間對它們執行的所有 JAX 操作(這發生在常規 Python 中)。然後,JAX 使用 tracer 記錄來重建整個函式。該重建的輸出是 jaxpr。由於 tracer 不記錄 Python 副作用,因此它們不會出現在 jaxpr 中。但是,副作用仍然在追蹤本身期間發生。

注意:Python print() 函式不是純函式:文字輸出是函式的副作用。因此,任何 print() 呼叫只會在追蹤期間發生,並且不會出現在 jaxpr 中

def log2_with_print(x):
  print("printed x:", x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print)(3.))
printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }

看看列印出的 x 如何成為 Traced 物件?那是 JAX 內部機制在運作。

Python 程式碼至少執行一次的事實嚴格來說是實作細節,因此不應依賴它。但是,了解它很有用,因為您可以在偵錯時使用它來列印出計算的中間值。

需要理解的關鍵是,jaxpr 捕捉了在給定參數下執行的函式。例如,如果我們有一個 Python 條件式,jaxpr 將只知道我們採取的分支

def log2_if_rank_2(x):
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))
{ lambda ; a:i32[3]. let  in (a,) }

JIT 編譯函式#

如前所述,JAX 使操作能夠在使用相同程式碼的情況下在 CPU/GPU/TPU 上執行。讓我們看一個計算縮放指數線性單元 (SELU) 的範例,這是一種常用於深度學習的操作

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()
4.49 ms ± 60.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

上面的程式碼一次向加速器發送一個操作。這限制了 XLA 編譯器最佳化我們函式的能力。

自然地,我們想要做的是盡可能多地向 XLA 編譯器提供程式碼,以便它可以完全最佳化它。為此,JAX 提供了 jax.jit() 轉換,它將 JIT 編譯相容於 JAX 的函式。下面的範例示範了如何使用 JIT 來加速先前的函式。

selu_jit = jax.jit(selu)

# Pre-compile the function before timing...
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()
987 μs ± 1.45 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

以下是剛才發生的事

  1. 我們將 selu_jit 定義為 selu 的編譯版本。

  2. 我們在 x 上呼叫了 selu_jit 一次。這是 JAX 執行追蹤的地方 – 畢竟它需要有一些輸入來包裝在 tracer 中。然後使用 XLA 將 jaxpr 編譯成非常有效率的程式碼,針對您的 GPU 或 TPU 進行最佳化。最後,執行編譯後的程式碼以滿足呼叫。後續對 selu_jit 的呼叫將直接使用編譯後的程式碼,完全跳過 Python 實作。(如果我們沒有單獨包含預熱呼叫,一切仍然可以運作,但編譯時間將包含在基準測試中。它仍然會更快,因為我們在基準測試中執行許多迴圈,但這將不是一個公平的比較。)

  3. 我們計時了編譯版本的執行速度。(請注意 block_until_ready() 的使用,這是由於 JAX 的 非同步調度 所需的)。

為什麼我們不能直接 JIT 所有東西?#

在看完上面的範例後,您可能想知道我們是否應該簡單地將 jax.jit() 應用於每個函式。為了理解為什麼不是這種情況,以及我們何時應該/不應該應用 jit,讓我們首先檢查一些 JIT 無法運作的情況。

# Condition on value of x.

def f(x):
  if x > 0:
    return x
  else:
    return 2 * x

jax.jit(f)(10)  # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_1109/2956679937.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
# While loop conditioned on x and n.

def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

jax.jit(g)(10, 20)  # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_1109/722961019.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

這兩種情況的問題都是我們嘗試使用執行階段值來條件化程式的追蹤時間流程。JIT 內的追蹤值,例如此處的 xn,只能透過其靜態屬性(例如 shapedtype)影響控制流程,而不能透過其值。有關 Python 控制流程和 JAX 之間互動的更多詳細資訊,請參閱 使用 JIT 的控制流程和邏輯運算子

處理此問題的一種方法是重寫程式碼以避免值條件式。另一種方法是使用特殊的 控制流程運算子,例如 jax.lax.cond()。但是,有時這是不可能或不切實際的。在這種情況下,您可以考慮僅 JIT 編譯函式的一部分。例如,如果函式中計算量最大的部分在迴圈內,我們可以僅 JIT 編譯該內部部分(雖然請務必檢查下一節關於快取的內容,以避免搬石頭砸自己的腳)

# While loop conditioned on x and n with a jitted body.

@jax.jit
def loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i

g_inner_jitted(10, 20)
Array(30, dtype=int32, weak_type=True)

將引數標記為靜態#

如果我們真的需要 JIT 編譯一個在輸入值上具有條件的函式,我們可以透過指定 static_argnumsstatic_argnames 來告訴 JAX 針對特定輸入使用較不抽象的 tracer。這樣做的代價是,產生的 jaxpr 和編譯後的成品取決於傳遞的特定值,因此 JAX 將必須針對指定的靜態輸入的每個新值重新編譯函式。只有在保證函式會看到有限組靜態值的情況下,這才是一個好的策略。

f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))
10
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20))
30

若要在使用 jit 作為裝飾器時指定此類引數,常見的模式是使用 Python 的 functools.partial()

from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

print(g_jit_decorated(10, 20))
30

JIT 和快取#

由於第一次 JIT 呼叫的編譯開銷,理解 jax.jit() 如何以及何時快取先前的編譯對於有效地使用它至關重要。

假設我們定義 f = jax.jit(g)。當我們第一次調用 f 時,它將被編譯,並且產生的 XLA 程式碼將被快取。後續對 f 的呼叫將重複使用快取的程式碼。這就是 jax.jit 如何彌補預先編譯成本的方式。

如果我們指定 static_argnums,則快取的程式碼將僅用於標記為靜態的引數的相同值。如果它們中的任何一個發生更改,則會發生重新編譯。如果有許多值,那麼您的程式可能會花費比逐個執行操作更多的時間進行編譯。

避免在迴圈或其他 Python 範圍內定義的臨時函式上呼叫 jax.jit()。在大多數情況下,JAX 將能夠在後續對 jax.jit() 的呼叫中使用已編譯、快取的函式。但是,由於快取依賴於函式的雜湊值,因此當重新定義等效函式時,它會變得有問題。這將導致迴圈中每次不必要的編譯

from functools import partial

def unjitted_loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! each time the partial returns
    # a function with different hash
    i = jax.jit(partial(unjitted_loop_body))(i)
  return x + i

def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this!, lambda will also return
    # a function with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i

def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is OK, since JAX can find the
    # cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()

print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()
jit called in a loop with partials:
350 ms ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
343 ms ± 6.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
3.81 ms ± 7.28 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)