偽隨機數#

如果所有科學論文的結果都因錯誤的 rand() 而受到質疑,而必須從圖書館書架上消失,那麼每個書架上都會出現一個像你的拳頭一樣大的空隙。 - Numerical Recipes

在本節中,我們將重點關注 jax.random 和偽隨機數生成 (PRNG);也就是說,演算法生成數字序列的過程,其屬性近似於從適當分佈中採樣的隨機數序列的屬性。

PRNG 生成的序列並非真正隨機,因為它們實際上是由其初始值決定的,初始值通常稱為 seed,並且隨機採樣的每個步驟都是某些 state 的確定性函數,該 state 從一個樣本傳遞到下一個樣本。

偽隨機數生成是任何機器學習或科學計算框架的重要組成部分。一般來說,JAX 努力與 NumPy 相容,但偽隨機數生成是一個顯著的例外。

為了更好地理解 JAX 和 NumPy 在隨機數生成方面採用的方法之間的差異,我們將在本節中討論這兩種方法。

NumPy 中的隨機數#

NumPy 中的 numpy.random 模組原生支援偽隨機數生成。在 NumPy 中,偽隨機數生成基於全域 state,可以使用 numpy.random.seed() 將其設定為確定性的初始條件。

import numpy as np
np.random.seed(0)

重複調用 NumPy 的有狀態偽隨機數產生器 (PRNG) 會改變全域狀態,並產生偽隨機數流

print(np.random.random())
print(np.random.random())
print(np.random.random())
0.5488135039273248
0.7151893663724195
0.6027633760716439

在底層,NumPy 使用 Mersenne Twister PRNG 來驅動其偽隨機函數。PRNG 的週期為 \(2^{19937}-1\),並且在任何時候都可以用 624 個 32 位元無號整數和一個位置來描述,該位置指示已使用了多少「熵」。

您可以使用以下命令檢查狀態的內容。

def print_truncated_random_state():
  """To avoid spamming the outputs, print only part of the state."""
  full_random_state = np.random.get_state()
  print(str(full_random_state)[:460], '...')

print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...

state 會在每次調用隨機函數時更新

np.random.seed(0)
print_truncated_random_state()
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...

NumPy 允許您在單個函數調用中採樣單個數字或整個數字向量。例如,您可以透過執行以下操作,從均勻分佈中採樣 3 個純量向量

np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135  0.71518937 0.60276338]

NumPy 提供循序等效保證,這表示連續個別採樣 N 個數字或採樣 N 個數字的向量會產生相同的偽隨機序列

np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
individually: [0.5488135  0.71518937 0.60276338]
all at once:  [0.5488135  0.71518937 0.60276338]

JAX 中的隨機數#

JAX 的隨機數生成在重要方面與 NumPy 的不同,因為 NumPy 的 PRNG 設計使其難以同時保證許多理想的屬性。具體來說,在 JAX 中,我們希望 PRNG 生成是

  1. 可重現的,

  2. 可平行化的,

  3. 可向量化的。

我們將在下面討論原因。首先,我們將重點關注基於全域狀態的 PRNG 設計的影響。考慮以下程式碼

import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())
1.9791922366721637

函數 foo 將從均勻分佈中採樣的兩個純量相加。

只有當我們假設 bar()baz() 的執行順序可預測時,此程式碼的輸出才能滿足要求 #1。這在 NumPy 中不是問題,NumPy 始終按照 Python 解譯器定義的順序評估程式碼。但是,在 JAX 中,這更成問題:為了有效執行,我們希望 JIT 編譯器可以自由地重新排序、省略和融合我們定義的函數中的各種操作。此外,在多裝置環境中執行時,每個進程都需要同步全域狀態,這會阻礙執行效率。

顯式隨機狀態#

為了避免這些問題,JAX 避免了隱含的全域隨機狀態,而是透過隨機 key 顯式追蹤狀態

from jax import random

key = random.key(42)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 42]

注意

本節使用由 jax.random.key() 產生的新型別化 PRNG 金鑰,而不是由 jax.random.PRNGKey() 產生的舊型原始 PRNG 金鑰。如需詳細資訊,請參閱 JEP 9263:類型化金鑰 & 可插拔 RNG

金鑰是一個陣列,具有對應於所使用特定 PRNG 實作的特殊 dtype;在預設實作中,每個金鑰都由一對 uint32 值支援。

金鑰有效地替代了 NumPy 的隱藏狀態物件,但我們將其顯式傳遞給 jax.random() 函數。重要的是,隨機函數會消耗金鑰,但不會修改它:將相同的金鑰物件饋送到隨機函數始終會產生相同的樣本。

print(random.normal(key))
print(random.normal(key))
-0.028304616
-0.028304616

重複使用相同的金鑰,即使使用不同的 random API,也可能導致相關的輸出,這通常是不希望發生的。

經驗法則是:永遠不要重複使用金鑰(除非您想要相同的輸出)。

JAX 使用現代 Threefry 基於計數器的 PRNG,它是可分割的。也就是說,它的設計允許我們將 PRNG 狀態分支到新的 PRNG 中,以用於平行隨機生成。為了生成不同且獨立的樣本,您必須在將金鑰傳遞給隨機函數之前,顯式地 split() 金鑰

for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.
draw 0: 0.6057640314102173
draw 1: -0.21089035272598267
draw 2: -0.3948981463909149

(此處不需要調用 del,但我們這樣做是為了強調一旦消耗金鑰就不應重複使用。)

jax.random.split() 是一個確定性函數,它將一個 key 轉換為多個獨立的(在偽隨機性意義上)金鑰。我們保留其中一個輸出作為 new_key,並且可以安全地使用唯一的額外金鑰(稱為 subkey)作為隨機函數的輸入,然後永遠丟棄它。如果您想從常態分佈中取得另一個樣本,您將再次分割 key,依此類推:關鍵是您永遠不會重複使用相同的金鑰。

我們將 split(key) 的哪個部分稱為 key,以及哪個部分稱為 subkey 並不重要。它們都是具有同等地位的獨立金鑰。金鑰/子金鑰命名慣例是一種典型的使用模式,有助於追蹤金鑰的消耗方式:子金鑰旨在立即被隨機函數消耗,而金鑰則被保留以在稍後產生更多隨機性。

通常,上面的範例會簡潔地寫成

key, subkey = random.split(key)

這會自動丟棄舊金鑰。值得注意的是,split() 可以根據您的需要建立任意數量的金鑰,而不僅僅是 2 個

key, *forty_two_subkeys = random.split(key, num=43)

缺乏循序等效性#

NumPy 和 JAX 的隨機模組之間的另一個差異與上面提到的循序等效保證有關。

與 NumPy 一樣,JAX 的隨機模組也允許採樣數字向量。但是,JAX 不提供循序等效保證,因為這樣做會干擾 SIMD 硬體上的向量化(上述要求 #3)。

在下面的範例中,使用三個子金鑰個別採樣常態分佈中的 3 個值,會得到與提供單個金鑰並指定 shape=(3,) 不同的結果

key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
individually: [0.07592554 0.60576403 0.4323065 ]
all at once:  [-0.02830462  0.46713185  0.29570296]

缺乏循序等效性讓我們可以更有效率地編寫程式碼;例如,我們可以不透過循序迴圈生成上面的 sequence,而是使用 jax.vmap() 以向量化方式計算相同的結果

import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
vectorized: [0.07592554 0.60576403 0.4323065 ]

下一步#

有關 JAX 隨機數的更多資訊,請參閱 jax.random 模組的文件。如果您對 JAX 隨機數產生器設計的詳細資訊感興趣,請參閱 JAX PRNG 設計