JAX PRNG 設計#

我們希望 PRNG 設計能夠

  1. 具有表現力,方便使用且不會限制使用者編寫具有完全符合他們需求的數值程式能力,

  2. 實現可重現的程式執行,且與後端無關,

  3. 具有語意,這些語意對於 @jit 編譯邊界和裝置後端是不變的

  4. 能夠向量化以使用 SIMD 硬體產生陣列值,

  5. 是可平行化的,因為它不會在隨機函式呼叫之間增加排序約束,否則這些呼叫將沒有資料依賴性,

  6. 可擴展到多副本、多核心和分散式計算

  7. 符合 JAX 和 XLA 語意和設計理念(最終由其他實際考量所驅動)。

作為這些的必然結果,我們認為設計應該是函數式的。另一個必然結果是,至少在目前的硬體限制下,我們將在軟體中完成 PRNG。

總之 JAX PRNG = Threefry 計數器 PRNG + 函數式陣列導向分割模型

目錄#

三種程式設計模型和玩具範例程式#

這是一個玩具範例,展示了像 Numpy 程式中常用的具狀態全域 PRNG

def foo(): return bar() + baz()
def bar(): return rand(RNG, (3, 4))
def baz(): return rand(RNG, (3, 4))
def main():
  global RNG
  RNG = RandomState(0)
  return foo()

為了在這裡實現可重現性,我們需要控制 bar() 和 baz() 的評估順序,即使它們之間沒有明確的資料依賴性。這種源於可重現性 (#2) 的排序要求違反了可平行化性 (#5),並且不符合 JAX 或 XLA 的函數式語意 (#6),在這些語意中,子表達式可以以任何順序評估。即使我們不要求可重現性,因此允許任何評估順序,跨呼叫的平行化 (#5) 仍然會因為需要更新共享狀態而變得困難。此外,由於相同的 PRNG 狀態需要在 Python 和任何編譯後的程式碼中存取和維護,因此此模型可能會導致工程挑戰,以實現編譯不變性 (#3) 和擴展到多個副本 (#6)。最後,表現力受到限制 (#1),因為 foo() 無法在不影響其自身(隱含)PRNG 狀態的情況下呼叫 bar() 或 baz()。

模型是否支援向量化 (#4) 取決於一些額外的細節。在 Numpy 中,PRNG 向量化受到序列等效保證的限制

In [1]: rng = np.random.RandomState(0)

In [2]: rng.randn(2)
Out[2]: array([1.76405235, 0.40015721])

In [3]: rng = np.random.RandomState(0)

In [4]: np.stack([rng.randn() for _ in range(2)])
Out[4]: array([1.76405235, 0.40015721])

為了允許在產生陣列的原始 PRNG 函式呼叫中進行向量化 (#4)(例如,使用形狀引數呼叫 rand()),我們放棄了此序列等效保證。此向量化可以由本節中討論的三種程式設計模型中的任何一種支援,儘管它激發了根據下一節中描述的基於計數器的 PRNG 進行實作。

具狀態的 PRNG 使用者程式設計模型並不理想。這是一個函數式模型的範例,但缺少我們稱之為分割的關鍵要素

def foo(rng_1):
   y, rng_2 = baz(rng_1)
   z, rng_3 = bar(rng_2)
   return y + z, rng_3

def bar(x, rng):
  val, new_rng = rand(rng, (3, 4))
  return val, new_rng

def baz(x, rng):
  val, new_rng = rand(rng, (3, 4))
  return val, new_rng

def main():
  foo(RandomState(0))

此模型明確地將 PRNG 狀態穿梭於所有產生隨機值的函式(原始或非原始):也就是說,每個隨機函式都必須接受並傳回狀態。現在在 foo() 中呼叫 baz() 和呼叫 bar() 之間存在明確的資料依賴性,因此資料流(以及排序)變得明確,並且符合 JAX 現有的語意 (#7),這與先前的模型不同。這種明確的線程處理也可以使語意對於編譯邊界 (#3) 保持不變。

明確的線程處理對於程式設計師來說很不方便。但更糟糕的是,它實際上並沒有提高表現力 (#1):foo() 仍然無法在保持自身 PRNG 狀態的情況下呼叫 bar() 或 baz()。在不了解其呼叫者或它們呼叫的子程序的情況下,函式必須在任何地方防禦性地傳入和傳回 rng 狀態。此外,它也沒有改善平行化 (#5) 或擴展到多個副本 (#6) 的前景,因為一切仍然是循序的,即使排序在函數式程式設計意義上變得明確。

簡而言之,透過明確地線程處理狀態使程式碼成為函數式,還不足以實現我們的表現力 (#1) 和效能 (#5, #6) 目標。

先前兩種模型的關鍵問題在於排序過多。為了減少順序依賴性,我們使用函數式可分割 PRNG。分割是一種將新的 PRNG 狀態「fork」到兩個 PRNG 狀態的機制,同時保持通常理想的 PRNG 屬性(這兩個新串流在計算上是可平行化的,並且產生獨立的隨機值,即它們的行為類似於多串流)。

def foo(rng_1):
   rng_2, rng_3 = split(rng_1, 2)
   return bar(rng_2) + baz(rng_3)

def bar(x, rng):
  return rand(rng, (3, 4))

def baz(x, rng):
  return rand(rng, (3, 4))

def main():
  foo(RandomState(0))

注意事項

  1. bar() 和 baz() 的呼叫之間沒有順序依賴性,它們可以以任何順序評估,而不會影響結果的值,這解決了剩餘的效能目標 (#5, #6),

  2. 函式不需要傳回 PRNG 的更新版本,並且可以直接呼叫隨機子程序,而不會影響現有的 PRNG 狀態,從而提高了其他函數式模型的表現力 (#1)。

範例沒有顯示這一點,但作為選擇 (2) 的結果,推進 PRNG 狀態的唯一方法是呼叫 split()。也就是說,我們有兩種方法可以實現 (1),它們的不同之處在於它們是否讓使用者程式承擔明確呼叫 split() 的負擔,如上述範例所示,或者讓使用者程式承擔明確線程處理的負擔。我們更喜歡前者,即具有明確分割的版本,因為我們可以輕鬆地根據它實作明確線程處理版本。

設計#

我們可以使用基於計數器的 PRNG 設計,特別是 Threefry 雜湊函式,如平行隨機數:像 1、2、3 一樣簡單中所述。我們使用計數器來實現高效的向量化:對於給定的金鑰,我們可以透過將雜湊函式映射到整數範圍 [k + 1, …, k + sample_size] 來以向量化方式產生值陣列。我們將金鑰與雜湊函式一起使用來實現可分割的 PRNG:也就是說,分割是一種從現有金鑰產生兩個新金鑰的方法。

type Sample = Int256
type Key = Sample  -- important identification for splitting
type Count = Int32

hash :: Key -> Count -> Int256  -- output type equal to Key and Sample

split :: Key -> (Key, Key)
split key = (hash key 0, hash key 1)

draw_samples :: Key -> Int -> [Sample]
draw_samples key n = map (hash key) [1..n]

令人驚訝的是,繪製樣本與分割非常相似!關鍵在於輸出型別的差異(即使型別被識別):在一個案例中,該值將用於形成感興趣的隨機樣本(例如,將隨機位元轉換為表示隨機常態的浮點數),而在另一個案例中,該值將用作進一步雜湊的金鑰。

雜湊函式引數(型別為 Key 和 Count)的不對稱性在於,後者很容易且計算成本低廉地推進任意量,因為我們只需要增加整數值,而前者僅透過雜湊推進。這就是為什麼我們使用計數參數進行向量化的原因。

更實際的範例使用者程式#

這是在主機上訓練迴圈的樣子,當步驟需要 PRNG 時(可能用於 dropout 或 VAE 訓練)

rng = lax.rng.new_rng()
for i in xrange(num_steps):
  rng, rng_input = lax.rng.split(rng)
  params = compiled_update(rng_input, params, next(batches))

請注意,我們讓使用者承擔了明確分割 rng 的責任,但 rng 完全不需要從程式碼中傳回。

以下是如何將此 PRNG 模型與 stax 神經網路建構器庫一起使用以實現 dropout

def Dropout(rate, mode='train'):
  def init_fun(input_shape):
    return input_shape, ()
  def apply_fun(rng, params, inputs):
    if mode == 'train':
      keep = lax.random.bernoulli(rng, rate, inputs.shape)
      return np.where(keep, inputs / rate, 0)
    else:
      return inputs
  return init_fun, apply_fun

這裡的 rng 值只是用於雜湊的金鑰,而不是特殊物件。rng 參數傳遞給每個 apply_fun,因此需要在具有分割的序列和並行組合器中處理它

def serial(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(input_shape):
    ...
  def apply_fun(rng, params, inputs):
    rngs = split(rng, len(layers))
    for rng, param, apply_fun in zip(rngs, params, apply_funs):
      inputs = apply_fun(rng, param, inputs)
    return inputs
  return init_fun, apply_fun

def parallel(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(input_shape):
    ...
  def apply_fun(rng, params, inputs):
    rngs = split(rng, len(layers))
    return [f(r, p, x) for f, r, p, x in zip(apply_funs, rngs, params, inputs)]
  return init_fun, apply_fun

這裡我們使用 split 的簡單擴充版本,它可以產生多個副本。

權衡和替代方案#

  1. 我們沒有利用任何裝置硬體 PRNG

    • 我們目前沒有足夠的控制權來控制所有後端的硬體 PRNG 狀態。

    • 即使我們這樣做了,它也會依賴於後端,我們可能必須在隨機呼叫之間引入順序依賴性,以確保確定性排序,從而確保可重現性。

    • 我們不知道有任何工作負載會讓軟體 PRNG 成為瓶頸。

    • 我們可以考慮提供一個額外的 API,允許想要放棄其他期望(例如嚴格的可重現性)的使用者存取硬體 PRNG。

  2. 我們放棄了序列等效保證,其中在一次呼叫中建立隨機陣列會產生與一次隨機元素一個接一個地建立扁平化陣列相同的值。

    • 此屬性可能與向量化(高優先順序)不相容。

    • 我們不知道有任何使用者或範例認為此屬性很重要。

    • 使用者可以在此 API 之上撰寫一個層,以提供此保證。

  3. 我們無法完全遵循 numpy.random API。