JAX 中的副作用排序#

sharadmv@ 2022 年 5 月 9 日

概述#

當我們撰寫 JAX 程式碼時,即使在底層,JAX 及其執行階段可能會在背景中非同步執行它,我們通常可以假裝我們正在撰寫單執行緒、急切執行的 Python。只要我們撰寫純粹(無副作用)的程式碼,這些效能最佳化通常對我們來說是不可見的,並且不會干擾我們的單執行緒心理模型。非同步執行非常棒 – 我們可以獲得高效能、平行的程式碼,而無需考慮它!

然而,在存在副作用的情況下,這種錯覺開始瓦解,我們心理模型中的裂縫開始顯現。具體來說,當我們考慮副作用發生的順序時,這些差異就會顯現出來。

在這份設計筆記中,我們探討 JAX 的執行模型與副作用排序之間的相互作用。我們也提供一種強制執行副作用「單執行緒」排序的方法。

背景#

當我們撰寫以下 Python 程式碼時

def f():
  print("hello")
  return 2
def g():
  print("world")
  return 3
f()
g()

我們預期 "hello" 會在 "world" 之前列印。這似乎很明顯,但請考慮以下 JAX 程式碼

@partial(jax.jit, device=<device 0>)
def f():
  return 2

@partial(jax.jit, device=<device 1>)
def g():
  return 3
f()
g()

在許多情況下,JAX 將平行執行 fg,將計算分派到不同的執行緒 – g 實際上可能會在 f 之前執行。平行執行是一種很好的效能最佳化,尤其是在複製到裝置和從裝置複製的成本很高時(有關更多詳細資訊,請參閱非同步分派筆記)。然而,在實務上,我們通常不需要考慮非同步分派,因為我們正在撰寫純函式,並且只關心函式的輸入和輸出 – 我們自然會封鎖未來的數值。

然而,現在想像一下,我們有一個 jax.print 函式,它可以在 JIT 編譯的 JAX 函式內部運作(host_callback.id_print 就是一個例子)。讓我們回到先前的範例,除了混合了列印。

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")
  return 2

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")
  return 3
f()
g()

由於非同步分派,我們實際上可能會看到 "world""hello" 之前列印。列印副作用的重新排序破壞了單執行緒執行模型的錯覺。

副作用可以「揭示」亂序執行的另一個範例是當我們編譯 JAX 程式時。請考慮以下 JAX 程式碼

@jax.jit
def f(x):
  jax.print("hello")
  jax.print("world")
  return x

即使在 Python 中,我們在 "world" 列印之前撰寫了 "hello" 列印,但像 XLA 這樣的編譯器可以自由地重新排序它們,因為列印之間沒有明確的資料相依性。

動機#

我們希望支援「排序」效果。當我們說排序時,我們指的是效果發生的順序與我們執行單執行緒 Python 程式時的順序相同。這是我們的主要目標。在存在明確平行性(如 pmap 或使用者執行緒)的情況下,我們不需要維護此行為,但至少在使用者未明確要求平行性的情況下,我們希望保留單執行緒排序。

在我們深入探討更多細節之前,讓我們先退後一步,問問我們自己,為了效能而重新排序效果是否可以接受,反之亦然,我們是否需要強制執行效果的排序?在某些情況下,我們不需要排序。也許某些副作用不應對 JAX 程式的效能產生不利影響。然而,對於其他副作用,我們可能希望強制執行單執行緒程式順序,以便使用者不會獲得違反直覺的行為。請考慮記錄效果。

@jax.jit
def f(x, y):
  log_value(x)
  log_value(y)
f(1, 2)

如果 log 正在變更全域列表,我們可能會預期我們在新增 y 之前新增 x。對於更嚴格的效果,我們可能希望選擇對效果進行排序。

強制執行排序效果#

我們用來強制執行計算排序的主要工具是資料相依性。簡而言之,如果函式 g 的輸入是函式 f 的輸出,則 f 必須在 g 之前執行。

然而,我們可能有像列印這樣的副作用,它們根本沒有輸入,因此我們無法天真地對它們進行排序。因此,我們使用 Token 作為將人為資料相依性注入計算的一種方式。

什麼是 Token?Token 只是可以穿梭於計算中的虛擬值。透過在多個計算中穿梭相同的 Token,我們強制它們必須以特定順序發生。讓我們以先前的列印範例為例,看看在混合中使用 Token 會是什麼樣子

@jax.jit
def f(token, x):
  token = jax.print(token, "hello")
  token = jax.print(token, "world")
  return token, x

如果我們重寫 jax.print 以接收和傳回 Token,我們現在已經對兩個列印進行排序,因為第二個列印的輸入取決於第一個列印的輸出。token 的實際值可以是任何值,但我們將在實務中看到 Token 對使用者來說是不可見的。

執行階段 Token 與編譯器 Token#

在這裡,我們實際上將開始討論實作細節。在實務上,我們需要兩種不同類型的 Token 來排序效果:每種類型對應上述重新排序來源之一。我們需要執行階段 Token 來排序非同步分派的副作用計算,並且我們需要編譯器 Token 來排序計算內的效果。

在實務上,我們的計算將被重寫成如下所示

@jax.jit
def f(runtime_token, x):
  compiler_token = new_compiler_token()
  compiler_token = jax.print(compiler_token, "hello")
  compiler_token = jax.print(compiler_token, "world")
  return runtime_token, x

請注意,執行階段 Token 僅在 JIT 邊界使用,而編譯器 Token 僅在編譯後的程式碼中使用。編譯器 Token 在「降低」期間建立(我們將 Python 程式碼轉換為較低層級的表示形式,如 HLO 或 StableHLO),但執行階段 Token 需要在 Python 中管理,因為它們正在穿梭於 JIT 編譯的函式中。

此外,請注意執行階段 Token 與編譯器 Token 是「斷開連接」的,這表示它們之間沒有資料相依性。如果我們將失去兩個分派的函式呼叫主體之間的資料相依性,這可能會很危險。但是,如果我們假設「嚴格執行」– 即,只有當分派函式的所有輸入都準備就緒且其所有輸出都將同時準備就緒時,分派函式才會開始執行 – 我們可以安全地建立新的編譯器 Token 並傳回非輸出相依的執行階段 Token。

管理執行階段 Token#

為了代表使用者管理執行階段 Token,我們需要掛鉤到 JAX 的分派機制中。每當我們呼叫 JIT 編譯的函式時,我們最終都會在如下所示的函式中結束

def _execute(compiled_computation, *args):
  outputs = compiled_computation.execute(*args)
  return outputs

此時,我們需要將執行階段 Token「注入」到計算中,並從計算的輸出中「提取」它們

def _execute(compiled_computation, *args):
  runtime_token = get_runtime_token() # Grab global token
  runtime_token, *outputs = compiled_computation.execute(runtime_token, *args)
  update_runtime_token(runtime_token) # Update global token
  return outputs

runtime_token 到底是什麼?嗯,我們需要能夠將其傳遞到 compiled_computation 中,這表示它需要是某種類型的陣列(目前,因為在編譯後的 JAX 程式碼內外沒有共用的 Token 表示形式)。在實務上,我們可以利用 (0,) 形狀的陣列將額外負荷降至最低。

我們還需要考慮多個裝置的使用案例,例如,第一個範例,我們首先在裝置 0 上呼叫 JIT 編譯的函式,然後在裝置 1 上呼叫一個函式。在這種情況下,我們還需要將從第一個計算傳回的執行階段 Token(位於裝置 0 上)複製到裝置 1,以便我們可以將其傳遞到第二個計算中。如果兩個後續計算共用相同的裝置,則不需要此複製。

新增編譯器 Token#

當我們將 Python 程式碼降低為 HLO 或 StableHLO 時,我們需要在計算開始時建立 Token,並確保當我們有需要排序的副作用計算時,它們可用。副作用計算將 Token 作為輸入並將其作為輸出傳回。

此 Token 執行緒的實作涉及升級 JAX 降低機制以自動執行此簿記。主要挑戰包括處理高階 primitives,如呼叫 primitives 和控制流程 primitives。我們將不會在這份設計筆記中詳細介紹如何處理這些 primitives。

封鎖輸出 Token#

為副作用計算新增對執行階段和編譯器 Token 的支援對於排序非常重要,但 Token 還有另一個微妙的使用案例,即封鎖副作用計算。即使我們不希望副作用計算被排序,我們可能仍然希望等待其完成。目前我們有 jax.block_until_ready,它會等待直到未來的值準備好結果。然而,對於副作用計算,我們可能有沒有傳回值但仍在執行副作用的函式。以下面的簡單範例為例

@jax.jit
def f():
  jax.print("hello world")
  return
f() # Executed asynchronously

此編譯後的計算不採用明確的輸入,也沒有明確的輸出。如果它是排序的列印效果,我們可以封鎖傳回的執行階段 Token,但是,當這是一個未排序的計算時,我們不會執行任何 Token 執行緒。當我們沒有輸出值可以呼叫 block_until_ready 時,我們如何等待 f() 完成執行?嗯,我們可以應用相同的 Token 策略,只是我們只傳回執行階段 Token,而不將它們作為輸入。這將為我們提供一個要封鎖的值,該值僅在 f() 完成執行後才會準備就緒。我們將這些 Token 稱為輸出 Token。我們最終得到一個如下所示的函式

@jax.jit
def f():
  jax.print("hello world")
  return new_runtime_token()
f() # Executed asynchronously

在底層,我們將以與管理執行階段 Token 相同的方式管理輸出 Token,但提供一種方法供使用者封鎖目前的輸出 Token 集。與執行階段 Token 不同,輸出 Token 需要特定於裝置。請考慮單一裝置的使用案例

@jax.jit
def f():
  jax.print("hello")

@jax.jit
def g():
  jax.print("world")

f()
g()

由於 f()g() 在相同的裝置上執行,因此封鎖 g() 的輸出 Token 實際上會封鎖 f(),因為(到目前為止!),JAX 執行階段不會交錯在相同裝置上執行的計算。如果這種情況發生變化,我們將必須修改整個設計。

但是,請考慮兩個裝置的使用案例

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")

f()
g()

在這裡,我們不想明確地排序 f()g(),但希望等待它們都完成。我們需要一個用於 f() 的輸出 Token 和一個用於 g() 的輸出 Token,我們將封鎖這兩個 Token

@partial(jax.jit, device=<device 0>)
def f():
  jax.print("hello")
  return new_runtime_token()

@partial(jax.jit, device=<device 1>)
def g():
  jax.print("world")
  return new_runtime_token()

t0 = f()
t1 = g()
block_until_ready((t0, t1))

因此,我們需要每個裝置的輸出 Token,以便我們可以避免排序不同裝置上的計算,同時提供封鎖副作用計算的能力。我們最終對 JAX 分派機制進行了以下(近似)變更

def _execute(compiled_computation, *args):
  output_token, *outputs = compiled_computation.execute(runtime_token, *args)
  update_output_token(output_token, compiled_computation.device)
  return outputs

我們還需要公開一個封鎖輸出 Token 的函式

def effects_barrier():
  output_token.block_until_ready()

請注意,封鎖輸出 Token 可能不是很常見,因為大多數 JAX 計算都會傳回一個要封鎖的值。然而,輸出 Token 對於測試和效能分析很有幫助,並且最好支援它,以便我們擁有一致且有凝聚力的效果系統。

更多細節#

  • 所有上述 Token 管理基礎架構都將是執行緒本機的。這表示每個使用者執行緒都將擁有自己的獨立執行階段 Token 串流。排序僅在使用者執行緒層級承諾。

  • 在實務上,我們每個效果有一個執行階段 Token。該效果的不同實例將被排序。這是為了避免排序可能彼此沒有任何關係的效果計算。從技術上講,這與我們強制執行單執行緒 Python 程式排序的最初目標背道而馳,但這是一個可以透過同時擁有「效果」特定 Token 和「全域」Token 來調整的權衡。