jax.random 模組#

用於虛擬隨機數產生的工具。

jax.random 套件提供許多常式,用於決定性地產生虛擬隨機數序列。

基本用法#

>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
...   key, subkey = jax.random.split(key)
...   params = compiled_update(subkey, params, next(batches))  

PRNG 金鑰#

與 NumPy 和 SciPy 使用者可能習慣的具狀態虛擬隨機數產生器 (PRNG) 不同,JAX 隨機函數都要求將明確的 PRNG 狀態作為第一個引數傳遞。隨機狀態由我們稱為金鑰的特殊陣列元素型別描述,通常由 jax.random.key() 函數產生

>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]

然後此金鑰可用於任何 JAX 的隨機數產生常式中

>>> random.uniform(key)
Array(0.947667, dtype=float32)

請注意,使用金鑰不會修改它,因此重複使用相同的金鑰將導致相同的結果

>>> random.uniform(key)
Array(0.947667, dtype=float32)

如果您需要新的隨機數,可以使用 jax.random.split() 來產生新的子金鑰

>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.00729382, dtype=float32)

注意

具有元素型別(例如上面的 key<fry>)的型別化金鑰陣列是在 JAX v0.4.16 中引入的。在此之前,金鑰通常以 uint32 陣列表示,其最後一個維度代表金鑰的位元級表示。

兩種形式的金鑰陣列仍然可以使用 jax.random 模組建立和使用。新的型別化金鑰陣列使用 jax.random.key() 建立。舊版的 uint32 金鑰陣列使用 jax.random.PRNGKey() 建立。

若要在兩者之間轉換,請使用 jax.random.key_data()jax.random.wrap_key_data()。當與 JAX 外部的系統介接(例如,將陣列匯出為可序列化格式),或將金鑰傳遞給假設舊版格式的基於 JAX 的函式庫時,可能需要舊版金鑰格式。

否則,建議使用型別化金鑰。相對於型別化金鑰,舊版金鑰的注意事項包括

  • 它們具有額外的尾隨維度。

  • 它們具有數值資料型別 (uint32),允許通常不應在金鑰上執行的操作,例如整數算術。

  • 它們不攜帶關於 RNG 實作的資訊。當舊版金鑰傳遞給 jax.random 函數時,全域設定配置決定 RNG 實作(請參閱下方的「進階 RNG 設定」)。

若要深入瞭解此升級以及金鑰型別的設計,請參閱 JEP 9263

進階#

設計與背景#

TLDR:JAX PRNG = Threefry 計數器 PRNG + 功能性陣列導向的 分割模型

請參閱 docs/jep/263-prng.md 以取得更多詳細資訊。

總之,除其他要求外,JAX PRNG 的目標是

  1. 確保可重現性,

  2. 在向量化(產生陣列值)和多副本、多核心計算方面都能良好地平行化。特別是,它不應使用隨機函數呼叫之間的排序約束。

進階 RNG 設定#

JAX 提供多種 PRNG 實作。可以使用 jax.random.key 的選用 impl 關鍵字引數選擇特定的實作。當沒有 impl 選項傳遞給 key 建構函式時,實作由全域 jax_default_prng_impl 設定旗標決定。可用實作的字串名稱為

  • "threefry2x32" (預設):基於 Threefry 雜湊函數變體的計數器式 PRNG,如 Salmon 等人,2011 年的論文所述。

  • "rbg""unsafe_rbg" (實驗性):建立在 XLA 的隨機位元產生器 (RBG) 演算法之上的 PRNG。

    • "rbg" 使用 XLA RBG 進行隨機數產生,而對於金鑰衍生(如 jax.random.splitjax.random.fold_in 中),它使用與 "threefry2x32" 相同的方法。

    • "unsafe_rbg" 將 XLA RBG 用於產生和金鑰衍生。

    這些實驗性方案產生的隨機數尚未經過經驗隨機性測試(例如 BigCrush)。

    "unsafe_rbg" 中的金鑰衍生也尚未經過實證測試。名稱強調「不安全」,因為金鑰衍生品質和產生品質尚未被很好地理解。

    此外,"rbg""unsafe_rbg"jax.vmap 下的行為異常。當 vmapping 隨機函數處理一批金鑰時,其輸出值可能與其在相同金鑰上的真實映射不同。相反地,在 vmap 下,整批輸出隨機數僅從輸入金鑰批次中的第一個金鑰產生。例如,如果 keys 是 8 個金鑰的向量,則 jax.vmap(jax.random.normal)(keys) 等於 jax.random.normal(keys[0], shape=(8,))。這種特殊性反映了 XLA RBG 有限批次處理支援的變通方法。

使用預設 RNG 替代方案的原因包括

  1. 在 TPU 上編譯可能很慢。

  2. 在 TPU 上執行相對較慢。

自動分割

為了使 jax.jit 有效地自動分割產生分片隨機數陣列(或金鑰陣列)的函數,所有 PRNG 實作都需要額外的旗標

  • 對於 "threefry2x32""rbg" 金鑰衍生,設定 jax_threefry_partitionable=True

  • 對於 "unsafe_rbg""rbg" 隨機產生”,設定 XLA 旗標 --xla_tpu_spmd_rng_bit_generator_unsafe=1

XLA 旗標可以使用 XLA_FLAGS 環境變數設定,例如 XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1

有關 jax_threefry_partitionable 的更多資訊,請參閱 https://jax.dev.org.tw/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers

摘要

屬性

Threefry

Threefry*

rbg

unsafe_rbg

rbg**

unsafe_rbg**

TPU 上最快

有效率地可分片 (使用 pjit)

跨分片相同

跨 CPU/GPU/TPU 相同

精確的 jax.vmap 處理金鑰

(*):設定 jax_threefry_partitionable=1

(**):設定 XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1

API 參考#

金鑰建立與操作#

key(seed, *[, impl])

使用整數種子建立虛擬隨機數產生器 (PRNG) 金鑰。

key_data(keys)

復原 PRNG 金鑰陣列底層的金鑰資料位元。

wrap_key_data(key_bits_array, *[, impl])

將金鑰資料位元陣列包裝到 PRNG 金鑰陣列中。

fold_in(key, data)

將資料摺疊到 PRNG 金鑰中以形成新的 PRNG 金鑰。

split(key[, num])

透過新增前導軸,將一個 PRNG 金鑰分割成 num 個新金鑰。

clone(key)

複製金鑰以供重複使用

PRNGKey(seed, *[, impl])

使用整數種子建立舊版 PRNG 金鑰。

隨機取樣器#

ball(key, d[, p, shape, dtype])

從單位 Lp 球均勻取樣。

bernoulli(key[, p, shape])

取樣具有給定形狀和平均值的 Bernoulli 隨機值。

beta(key, a, b[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的 Beta 隨機值。

binomial(key, n, p[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的二項式隨機值。

bits(key[, shape, dtype])

以無號整數的形式取樣均勻位元。

categorical(key, logits[, axis, shape])

從類別分佈中取樣隨機值。

cauchy(key[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的 Cauchy 隨機值。

chisquare(key, df[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的卡方隨機值。

choice(key, a[, shape, replace, p, axis])

從給定陣列產生隨機樣本。

dirichlet(key, alpha[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的 Dirichlet 隨機值。

double_sided_maxwell(key, loc, scale[, ...])

從雙邊 Maxwell 分佈取樣。

exponential(key[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的指數隨機值。

f(key, dfnum, dfden[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的 F 分佈隨機值。

gamma(key, a[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的 Gamma 隨機值。

generalized_normal(key, p[, shape, dtype])

從廣義常態分佈取樣。

geometric(key, p[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的幾何隨機值。

gumbel(key[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的 Gumbel 隨機值。

laplace(key[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的 Laplace 隨機值。

loggamma(key, a[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的對數 Gamma 隨機值。

logistic(key[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的 Logistic 隨機值。

lognormal(key[, sigma, shape, dtype])

取樣具有給定形狀和浮點數資料型別的對數常態隨機值。

maxwell(key[, shape, dtype])

從單邊 Maxwell 分佈取樣。

multivariate_normal(key, mean, cov[, shape, ...])

取樣具有給定平均值和共變異數的多元常態隨機值。

normal(key[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的標準常態隨機值。

orthogonal(key, n[, shape, dtype])

從正交群 O(n) 均勻取樣。

pareto(key, b[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的 Pareto 隨機值。

permutation(key, x[, axis, independent])

傳回隨機排列的陣列或範圍。

poisson(key, lam[, shape, dtype])

取樣具有給定形狀和整數資料型別的 Poisson 隨機值。

rademacher(key[, shape, dtype])

從 Rademacher 分佈取樣。

randint(key, shape, minval, maxval[, dtype])

取樣 [minval, maxval) 範圍內具有給定形狀/資料型別的均勻隨機值。

rayleigh(key, scale[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的 Rayleigh 隨機值。

t(key, df[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的 Student's t 隨機值。

triangular(key, left, mode, right[, shape, ...])

取樣具有給定形狀和浮點數資料型別的三角隨機值。

truncated_normal(key, lower, upper[, shape, ...])

取樣具有給定形狀和資料型別的截斷標準常態隨機值。

uniform(key[, shape, dtype, minval, maxval])

取樣 [minval, maxval) 範圍內具有給定形狀/資料型別的均勻隨機值。

wald(key, mean[, shape, dtype])

取樣具有給定形狀和浮點數資料型別的 Wald 隨機值。

weibull_min(key, scale, concentration[, ...])

從 Weibull 分佈取樣。