jax.experimental.checkify.check_error#

jax.experimental.checkify.check_error(error)[原始碼]#

如果 error 代表失敗,則引發例外。透過 checkify() 函數化。

此函數的語意等同於

>>> def check_error(err: Error) -> None:
...   err.throw()  # can raise ValueError

但與該實作不同,check_error 可以使用 checkify() 轉換進行函數化。

此函數類似於 check(),但具有不同的簽名:check() 接受布林述詞和新的錯誤訊息字串作為引數,而此函數接受 Error 值作為引數。check() 和此函數都會在失敗時引發 Python 例外 (副作用),因此無法透過 jit()pmap()scan() 等進行 staged out。兩者也可以透過使用 checkify() 進行函數化。

但與 check() 不同,此函數更像是 checkify() 的直接反向:checkify() 接受可能引發 Python 例外的函數作為輸入,並產生一個沒有該副作用的新函數,但會產生 Error 值作為輸出,此 check_error 函數可以接受 Error 值作為輸入,並可能產生引發例外的副作用。也就是說,雖然 checkify() 從可函數化的例外副作用轉變為錯誤值,但此 check_error 從錯誤值轉變為可函數化的例外副作用。

當您想要將由 Error 值表示的檢查 (透過 checkify() 函數化 checks 所產生) 轉回 Python 例外時,check_error 非常有用。

參數:

error (Error) – 要檢查的錯誤。

返回類型:

None

例如,您可能想要透過 checkify 函數化程式碼的一部分,透過 jit() staged out 您的函數化程式碼,然後在 jit() 之外重新注入您的錯誤值

>>> import jax
>>> from jax.experimental import checkify
>>> def f(x):
...   checkify.check(x>0, "must be positive!")
...   return x
>>> def with_inner_jit(x):
...   checked_f = checkify.checkify(f)
...   # a checkified function can be jitted
...   error, out = jax.jit(checked_f)(x)
...   checkify.check_error(error)
...   return out
>>> _ = with_inner_jit(1)  # no failed check
>>> with_inner_jit(-1)  
Traceback (most recent call last):
  ...
jax._src.JaxRuntimeError: must be positive!
>>> # can re-checkify
>>> error, _ = checkify.checkify(with_inner_jit)(-1)