JEP 9263:具型別金鑰與可插拔 RNG#

Jake VanderPlas、Roy Frostig

2023 年 8 月

概觀#

展望未來,JAX 中的 RNG 金鑰將更具型別安全性和可自訂性。它不會以長度為 2 的 uint32 陣列表示單一 PRNG 金鑰,而是以具有特殊 RNG 資料型別的純量陣列表示,該資料型別滿足 jnp.issubdtype(key.dtype, jax.dtypes.prng_key)

目前,舊式 RNG 金鑰仍可使用 jax.random.PRNGKey() 建立

>>> key = jax.random.PRNGKey(0)
>>> key
Array([0, 0], dtype=uint32)
>>> key.shape
(2,)
>>> key.dtype
dtype('uint32')

從現在開始,可以使用 jax.random.key() 建立新式 RNG 金鑰

>>> key = jax.random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> key.shape
()
>>> key.dtype
key<fry>

此(純量形狀)陣列的行為與任何其他 JAX 陣列相同,只是其元素型別是金鑰(和相關中繼資料)。我們也可以建立非純量金鑰陣列,例如將 jax.vmap() 應用於 jax.random.key()

>>> key_arr = jax.vmap(jax.random.key)(jnp.arange(4))
>>> key_arr
Array((4,), dtype=key<fry>) overlaying:
[[0 0]
 [0 1]
 [0 2]
 [0 3]]
>>> key_arr.shape
(4,)

除了切換到新的建構函式之外,大多數與 PRNG 相關的程式碼應繼續如預期般運作。您可以繼續像以前一樣在 jax.random API 中使用金鑰;例如

# split
new_key, subkey = jax.random.split(key)

# random number generation
data = jax.random.uniform(key, shape=(5,))

但是,並非所有數值運算都適用於金鑰陣列。它們現在會故意引發錯誤

>>> key = key + 1  
Traceback (most recent call last):
TypeError: add does not accept dtypes key<fry>, int32.

如果由於某些原因您需要恢復底層緩衝區(舊式金鑰),可以使用 jax.random.key_data() 來執行此操作

>>> jax.random.key_data(key)
Array([0, 0], dtype=uint32)

對於舊式金鑰,key_data() 是一個恆等運算。

這對使用者意味著什麼?#

對於 JAX 使用者來說,此變更目前不需要任何程式碼變更,但我們希望您會發現升級值得,並切換到使用型別化金鑰。若要試用此功能,請將 jax.random.PRNGKey() 的使用替換為 jax.random.key()。這可能會在您的程式碼中引入一些類別的錯誤

  • 如果您的程式碼對金鑰執行不安全/不受支援的運算(例如索引、算術、轉置等;請參閱下方的「型別安全」章節),此變更將會捕捉到它。您可以更新您的程式碼以避免此類不受支援的運算,或使用 jax.random.key_data()jax.random.wrap_key_data() 以不安全的方式操作原始金鑰緩衝區。

  • 如果您的程式碼包含關於 key.shape 的明確邏輯,您可能需要更新此邏輯以考量到尾隨金鑰緩衝區維度不再是形狀的明確部分。

  • 如果您的程式碼包含關於 key.dtype 的明確邏輯,您將需要升級它以使用新的公開 API 來推論 RNG 資料型別,例如 dtypes.issubdtype(dtype, dtypes.prng_key)

  • 如果您呼叫尚不處理型別化 PRNG 金鑰的基於 JAX 的程式庫,您可以暫時使用 raw_key = jax.random.key_data(key) 來恢復原始緩衝區,但請保留 TODO 以在下游程式庫支援型別化 RNG 金鑰後移除此程式碼。

在未來的某個時間點,我們計劃棄用 jax.random.PRNGKey(),並要求使用 jax.random.key()

偵測新型別化的金鑰#

若要檢查物件是否為新型別化的 PRNG 金鑰,您可以使用 jax.dtypes.issubdtypejax.numpy.issubdtype

>>> typed_key = jax.random.key(0)
>>> jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
True
>>> raw_key = jax.random.PRNGKey(0)
>>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key)
False

PRNG 金鑰的型別註解#

新舊 PRNG 金鑰的建議型別註解都是 jax.Array。PRNG 金鑰根據其 dtype 與其他陣列區分開來,目前無法在型別註解中指定 JAX 陣列的資料型別。先前可以使用 jax.random.KeyArrayjax.random.PRNGKeyArray 作為型別註解,但這些註解在型別檢查下始終別名為 Any,因此 jax.Array 具有更高的特異性。

注意:jax.random.KeyArrayjax.random.PRNGKeyArray 已在 JAX 0.4.16 版中棄用,並在 JAX 0.4.24 版中移除.

JAX 程式庫作者的注意事項#

如果您維護基於 JAX 的程式庫,您的使用者也是 JAX 使用者。請注意,JAX 目前將繼續在 jax.random 中支援「原始」舊式金鑰,因此呼叫者可能會期望它們在任何地方都保持可接受。如果您希望在您的程式庫中要求新型別化的金鑰,則您可能需要使用類似以下程式碼的檢查來強制執行它們

from jax import dtypes

def ensure_typed_key_array(key: Array) -> Array:
  if dtypes.issubdtype(key.dtype, dtypes.prng_key):
    return key
  else:
    raise TypeError("New-style typed JAX PRNG keys required")

動機#

此變更的兩個主要動機因素是可自訂性和安全性。

自訂 PRNG 實作#

JAX 目前使用單一、全域設定的 PRNG 演算法運作。PRNG 金鑰是無符號 32 位元整數的向量,jax.random API 使用這些整數來產生虛擬隨機串流。任何較高秩的 uint32 陣列都被解譯為此類金鑰緩衝區的陣列,其中尾隨維度表示金鑰。

當我們引入替代 PRNG 實作時,此設計的缺點變得更加明顯,必須透過設定全域或本機組態旗標來選取這些實作。不同的 PRNG 實作具有不同大小的金鑰緩衝區,以及用於產生隨機位元的不同演算法。使用全域旗標判斷此行為容易出錯,尤其是在整個進程中使用多個金鑰實作時。

我們的新方法是將實作作為 PRNG 金鑰型別的一部分來攜帶,即使用金鑰陣列的元素型別。使用新的金鑰 API,以下是在預設 threefry2x32 實作(在純 Python 中實作並使用 JAX 編譯)和非預設 rbg 實作(對應於單一 XLA 隨機位元產生運算)下產生虛擬隨機值的範例

>>> key = jax.random.key(0, impl='threefry2x32')  # this is the default impl
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.947667  , 0.9785799 , 0.33229148], dtype=float32)

>>> key = jax.random.key(0, impl='rbg')
>>> key
Array((), dtype=key<rbg>) overlaying:
[0 0 0 0]
>>> jax.random.uniform(key, shape=(3,))
Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)

安全的 PRNG 金鑰使用#

原則上,PRNG 金鑰實際上僅用於支援少數運算,即金鑰衍生(例如分割)和隨機數產生。PRNG 旨在產生獨立的虛擬隨機數,前提是金鑰已正確分割且每個金鑰都使用一次。

以其他方式操作或使用金鑰資料的程式碼通常表示意外的錯誤,並且將金鑰陣列表示為原始 uint32 緩衝區使得沿這些方向的容易誤用。以下是我們在實際應用中遇到的一些誤用範例

金鑰緩衝區索引#

存取底層整數緩衝區使得以非標準方式嘗試衍生金鑰變得容易,有時會產生出乎意料的糟糕後果

# Incorrect
key = random.PRNGKey(999)
new_key = random.PRNGKey(key[1])  # identical to the original key!
# Correct
key = random.PRNGKey(999)
key, new_key = random.split(key)

如果此金鑰是由 random.key(999) 建立的新型別化金鑰,則索引到金鑰緩衝區會產生錯誤。

金鑰算術#

金鑰算術是一種類似的危險方式,可以從其他金鑰衍生金鑰。以避免 jax.random.split()jax.random.fold_in() 的方式衍生金鑰,透過直接操作金鑰資料,會產生一批金鑰,然後這些金鑰(取決於 PRNG 實作)可能會在批次內產生相關的隨機數

# Incorrect
key = random.PRNGKey(0)
batched_keys = key + jnp.arange(10, dtype=key.dtype)[:, None]
# Correct
key = random.PRNGKey(0)
batched_keys = random.split(key, 10)

使用 random.key(0) 建立的新型別化金鑰透過不允許對金鑰進行算術運算來解決此問題。

不慎轉置金鑰緩衝區#

使用「原始」舊式金鑰陣列,很容易意外交換批次(前導)維度和金鑰緩衝區(尾隨)維度。同樣,這可能會導致金鑰產生相關的虛擬隨機性。我們隨著時間推移看到的一種模式可以歸結為以下內容

# Incorrect
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=1)(keys)
# Correct
keys = random.split(random.PRNGKey(0))
data = jax.vmap(random.uniform, in_axes=0)(keys)

此處的錯誤很微妙。透過映射 in_axes=1,此程式碼透過組合批次中每個金鑰緩衝區的單一元素來建立新金鑰。產生的金鑰彼此不同,但實際上是以非標準方式「衍生」的。同樣,PRNG 並非設計或測試為從此類金鑰批次產生獨立的隨機串流。

使用 random.key(0) 建立的新型別化金鑰透過隱藏個別金鑰的緩衝區表示來解決此問題,而是將金鑰視為金鑰陣列的不透明元素。金鑰陣列沒有要索引、轉置或映射的尾隨「緩衝區」維度。

金鑰重複使用#

與基於狀態的 PRNG API(例如 numpy.random)不同,JAX 的函式 PRNG 在使用金鑰後不會隱式更新金鑰。

# Incorrect
key = random.PRNGKey(0)
x = random.uniform(key, (100,))
y = random.uniform(key, (100,))  # Identical values!
# Correct
key = random.PRNGKey(0)
key1, key2 = random.split(random.key(0))
x = random.uniform(key1, (100,))
y = random.uniform(key2, (100,))

我們正在積極開發工具來偵測和防止意外的金鑰重複使用。這項工作仍在進行中,但它依賴於型別化金鑰陣列。現在升級到型別化金鑰為我們設定了在我們構建這些安全功能時引入它們的基礎。

型別化 PRNG 金鑰的設計#

型別化 PRNG 金鑰實作為 JAX 中擴充資料型別的實例,其中新的 PRNG 資料型別是子資料型別。

擴充資料型別#

從使用者的角度來看,擴充資料型別 dt 具有以下使用者可見的屬性

  • jax.dtypes.issubdtype(dt, jax.dtypes.extended) 傳回 True:這是應該用於偵測資料型別是否為擴充資料型別的公開 API。

  • 它具有類別層級屬性 dt.type,該屬性傳回 numpy.generic 階層中的型別類別。這類似於 np.dtype('int32').type 如何傳回 numpy.int32,這不是資料型別,而是純量型別,並且是 numpy.generic 的子類別。

  • 與 numpy 純量型別不同,我們不允許實例化 dt.type 純量物件:這符合 JAX 將純量值表示為零維陣列的決策。

從非公開實作的角度來看,擴充資料型別具有以下屬性

  • 其型別是私有基底類別 jax._src.dtypes.ExtendedDtype 的子類別,jax._src.dtypes.ExtendedDtype 是用於擴充資料型別的非公開基底類別。ExtendedDtype 的實例類似於 np.dtype 的實例,例如 np.dtype('int32')

  • 它具有私有 _rules 屬性,該屬性允許資料型別定義其在特定運算下的行為方式。例如,當 dtype 是擴充資料型別時,jax.lax.full(shape, fill_value, dtype) 將委派給 dtype._rules.full(shape, fill_value, dtype)

為什麼要泛泛地引入擴充資料型別,而不僅僅是 PRNG?我們在內部其他地方重複使用相同的擴充資料型別機制。例如,jax._src.core.bint 物件(用於動態形狀實驗工作的有界整數型別)是另一種擴充資料型別。在最近的 JAX 版本中,它滿足上述屬性(請參閱 jax/_src/core.py#L1789-L1802)。

PRNG 資料型別#

PRNG 資料型別定義為擴充資料型別的特定情況。具體而言,此變更引入了新的公開純量型別類別 jax.dtypes.prng_key,它具有以下屬性

>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended)
True

然後,PRNG 金鑰陣列具有具有以下屬性的資料型別

>>> key = jax.random.key(0)
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.extended)
True
>>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
True

除了針對一般擴充資料型別概述的 key.dtype._rules 之外,PRNG 資料型別還定義了 key.dtype._impl,其中包含定義 PRNG 實作的中繼資料。PRNG 實作目前由非公開 jax._src.prng.PRNGImpl 類別定義。目前,PRNGImpl 並非旨在成為公開 API,但我們可能會很快重新審視這一點,以允許完全自訂的 PRNG 實作。

進度#

以下是實作上述設計的關鍵提取請求的非詳盡列表。主要追蹤問題是 #9263

  • 透過 PRNGImpl 實作可插拔 PRNG:#6899

  • 實作 PRNGKeyArray,不含資料型別:#11952

  • 將「自訂元素」資料型別屬性新增至具有 _rules 屬性的 PRNGKeyArray#12167

  • 將「自訂元素型別」重新命名為「不透明資料型別」:#12170

  • 重構 bint 以使用不透明資料型別基礎結構:#12707

  • 新增 jax.random.key 以直接建立型別化金鑰:#16086

  • impl 引數新增至 keyPRNGKey#16589

  • 將「不透明資料型別」重新命名為「擴充資料型別」& 定義 jax.dtypes.extended#16824

  • 引入 jax.dtypes.prng_key 並將 PRNG 資料型別與擴充資料型別統一:#16781

  • 新增 jax_legacy_prng_key 旗標,以支援在使用舊版(原始)PRNG 金鑰時發出警告或錯誤:#17225