jax.random
模組#
用於虛擬隨機數產生的工具。
jax.random
套件提供許多常式,用於決定性地產生虛擬隨機數序列。
基本用法#
>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
... key, subkey = jax.random.split(key)
... params = compiled_update(subkey, params, next(batches))
PRNG 金鑰#
與 NumPy 和 SciPy 使用者可能習慣的具狀態虛擬隨機數產生器 (PRNG) 不同,JAX 隨機函數都要求將明確的 PRNG 狀態作為第一個引數傳遞。隨機狀態由我們稱為金鑰的特殊陣列元素型別描述,通常由 jax.random.key()
函數產生
>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
然後此金鑰可用於任何 JAX 的隨機數產生常式中
>>> random.uniform(key)
Array(0.947667, dtype=float32)
請注意,使用金鑰不會修改它,因此重複使用相同的金鑰將導致相同的結果
>>> random.uniform(key)
Array(0.947667, dtype=float32)
如果您需要新的隨機數,可以使用 jax.random.split()
來產生新的子金鑰
>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.00729382, dtype=float32)
注意
具有元素型別(例如上面的 key<fry>
)的型別化金鑰陣列是在 JAX v0.4.16 中引入的。在此之前,金鑰通常以 uint32
陣列表示,其最後一個維度代表金鑰的位元級表示。
兩種形式的金鑰陣列仍然可以使用 jax.random
模組建立和使用。新的型別化金鑰陣列使用 jax.random.key()
建立。舊版的 uint32
金鑰陣列使用 jax.random.PRNGKey()
建立。
若要在兩者之間轉換,請使用 jax.random.key_data()
和 jax.random.wrap_key_data()
。當與 JAX 外部的系統介接(例如,將陣列匯出為可序列化格式),或將金鑰傳遞給假設舊版格式的基於 JAX 的函式庫時,可能需要舊版金鑰格式。
否則,建議使用型別化金鑰。相對於型別化金鑰,舊版金鑰的注意事項包括
它們具有額外的尾隨維度。
它們具有數值資料型別 (
uint32
),允許通常不應在金鑰上執行的操作,例如整數算術。它們不攜帶關於 RNG 實作的資訊。當舊版金鑰傳遞給
jax.random
函數時,全域設定配置決定 RNG 實作(請參閱下方的「進階 RNG 設定」)。
若要深入瞭解此升級以及金鑰型別的設計,請參閱 JEP 9263。
進階#
設計與背景#
TLDR:JAX PRNG = Threefry 計數器 PRNG + 功能性陣列導向的 分割模型
請參閱 docs/jep/263-prng.md 以取得更多詳細資訊。
總之,除其他要求外,JAX PRNG 的目標是
確保可重現性,
在向量化(產生陣列值)和多副本、多核心計算方面都能良好地平行化。特別是,它不應使用隨機函數呼叫之間的排序約束。
進階 RNG 設定#
JAX 提供多種 PRNG 實作。可以使用 jax.random.key
的選用 impl
關鍵字引數選擇特定的實作。當沒有 impl
選項傳遞給 key
建構函式時,實作由全域 jax_default_prng_impl
設定旗標決定。可用實作的字串名稱為
"threefry2x32"
(預設):基於 Threefry 雜湊函數變體的計數器式 PRNG,如 Salmon 等人,2011 年的論文所述。"rbg"
和"unsafe_rbg"
(實驗性):建立在 XLA 的隨機位元產生器 (RBG) 演算法之上的 PRNG。"rbg"
使用 XLA RBG 進行隨機數產生,而對於金鑰衍生(如jax.random.split
和jax.random.fold_in
中),它使用與"threefry2x32"
相同的方法。"unsafe_rbg"
將 XLA RBG 用於產生和金鑰衍生。
這些實驗性方案產生的隨機數尚未經過經驗隨機性測試(例如 BigCrush)。
"unsafe_rbg"
中的金鑰衍生也尚未經過實證測試。名稱強調「不安全」,因為金鑰衍生品質和產生品質尚未被很好地理解。此外,
"rbg"
和"unsafe_rbg"
在jax.vmap
下的行為異常。當 vmapping 隨機函數處理一批金鑰時,其輸出值可能與其在相同金鑰上的真實映射不同。相反地,在vmap
下,整批輸出隨機數僅從輸入金鑰批次中的第一個金鑰產生。例如,如果keys
是 8 個金鑰的向量,則jax.vmap(jax.random.normal)(keys)
等於jax.random.normal(keys[0], shape=(8,))
。這種特殊性反映了 XLA RBG 有限批次處理支援的變通方法。
使用預設 RNG 替代方案的原因包括
在 TPU 上編譯可能很慢。
在 TPU 上執行相對較慢。
自動分割
為了使 jax.jit
有效地自動分割產生分片隨機數陣列(或金鑰陣列)的函數,所有 PRNG 實作都需要額外的旗標
對於
"threefry2x32"
和"rbg"
金鑰衍生,設定jax_threefry_partitionable=True
。對於
"unsafe_rbg"
和"rbg"
隨機產生”,設定 XLA 旗標--xla_tpu_spmd_rng_bit_generator_unsafe=1
。
XLA 旗標可以使用 XLA_FLAGS
環境變數設定,例如 XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
。
有關 jax_threefry_partitionable
的更多資訊,請參閱 https://jax.dev.org.tw/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers
摘要
屬性 |
Threefry |
Threefry* |
rbg |
unsafe_rbg |
rbg** |
unsafe_rbg** |
---|---|---|---|---|---|---|
TPU 上最快 |
✅ |
✅ |
✅ |
✅ |
||
有效率地可分片 (使用 pjit) |
✅ |
✅ |
✅ |
|||
跨分片相同 |
✅ |
✅ |
✅ |
✅ |
||
跨 CPU/GPU/TPU 相同 |
✅ |
✅ |
||||
精確的 |
✅ |
✅ |
(*):設定 jax_threefry_partitionable=1
(**):設定 XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1
API 參考#
金鑰建立與操作#
|
使用整數種子建立虛擬隨機數產生器 (PRNG) 金鑰。 |
|
復原 PRNG 金鑰陣列底層的金鑰資料位元。 |
|
將金鑰資料位元陣列包裝到 PRNG 金鑰陣列中。 |
|
將資料摺疊到 PRNG 金鑰中以形成新的 PRNG 金鑰。 |
|
透過新增前導軸,將一個 PRNG 金鑰分割成 num 個新金鑰。 |
|
複製金鑰以供重複使用 |
|
使用整數種子建立舊版 PRNG 金鑰。 |
隨機取樣器#
|
從單位 Lp 球均勻取樣。 |
|
取樣具有給定形狀和平均值的 Bernoulli 隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的 Beta 隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的二項式隨機值。 |
|
以無號整數的形式取樣均勻位元。 |
|
從類別分佈中取樣隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的 Cauchy 隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的卡方隨機值。 |
|
從給定陣列產生隨機樣本。 |
|
取樣具有給定形狀和浮點數資料型別的 Dirichlet 隨機值。 |
|
從雙邊 Maxwell 分佈取樣。 |
|
取樣具有給定形狀和浮點數資料型別的指數隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的 F 分佈隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的 Gamma 隨機值。 |
|
從廣義常態分佈取樣。 |
|
取樣具有給定形狀和浮點數資料型別的幾何隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的 Gumbel 隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的 Laplace 隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的對數 Gamma 隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的 Logistic 隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的對數常態隨機值。 |
|
從單邊 Maxwell 分佈取樣。 |
|
取樣具有給定平均值和共變異數的多元常態隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的標準常態隨機值。 |
|
從正交群 O(n) 均勻取樣。 |
|
取樣具有給定形狀和浮點數資料型別的 Pareto 隨機值。 |
|
傳回隨機排列的陣列或範圍。 |
|
取樣具有給定形狀和整數資料型別的 Poisson 隨機值。 |
|
從 Rademacher 分佈取樣。 |
|
取樣 [minval, maxval) 範圍內具有給定形狀/資料型別的均勻隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的 Rayleigh 隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的 Student's t 隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的三角隨機值。 |
|
取樣具有給定形狀和資料型別的截斷標準常態隨機值。 |
|
取樣 [minval, maxval) 範圍內具有給定形狀/資料型別的均勻隨機值。 |
|
取樣具有給定形狀和浮點數資料型別的 Wald 隨機值。 |
|
從 Weibull 分佈取樣。 |