jax.experimental.key_reuse 模組#

實驗性金鑰重複使用檢查#

此模組包含用於偵測 JAX 程式中隨機金鑰重複使用的實驗性功能。它正在積極開發中,此處的 API 可能會變更。以下用法需要 JAX 版本 0.4.26 或更新版本。

可以使用 jax_debug_key_reuse 設定啟用金鑰重複使用檢查。這可以使用以下方式全域設定

>>> jax.config.update('jax_debug_key_reuse', True)  

或者可以使用 KeyReuseError jax.debug_key_reuse() 內容管理器在本機啟用。啟用後,重複使用相同的金鑰兩次將導致

>>> import jax
>>> with jax.debug_key_reuse(True):
...   key = jax.random.key(0)
...   val1 = jax.random.normal(key)
...   val2 = jax.random.normal(key)  
Traceback (most recent call last):
 ...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0

金鑰重複使用檢查器目前為實驗性,但在未來我們可能會預設啟用它。