jax.experimental.checkify.checkify#

jax.experimental.checkify.checkify(f, errors=frozenset({<class 'jax._src.checkify.FailedCheckError'>}))[原始碼]#

fun 中的 check 呼叫函數化,並選擇性地新增執行階段錯誤檢查。

執行階段錯誤可以是使用者新增的 check() 斷言,或自動新增的檢查(例如 NaN 檢查),取決於 errors 引數。

傳回的函式將傳回 Error 物件 err 以及原始函式的輸出。err.get() 將傳回 None(如果沒有發生錯誤)或包含錯誤訊息的字串。此錯誤訊息將對應於發生的第一個錯誤。err.throw() 將在發生錯誤時引發 ValueError 並顯示錯誤訊息。

預設情況下,僅啟用使用者新增的 check() 斷言。您可以透過 errors 引數啟用自動檢查。

可以啟用的自動檢查集合,以及何時產生錯誤
  • user_checkscheck() 評估為 False。

  • nan_checks:浮點運算產生 NaN 值作為輸出。

  • div_checks:除以零。

  • index_checks:索引超出範圍。

可以透過傳入錯誤 Set (例如 errors=nan_checks) 一起啟用多個類別。可以重新組合多個集合 (例如 errors=float_checks|user_checks)

參數:
  • fun – 可包含使用者檢查的 Callable (請參閱 check())。

  • errors (frozenset[ErrorCategory]) – ErrorCategory 值的集合,用於定義已啟用的檢查集合。預設情況下,僅啟用顯式 checks (user_checks)。您也可以例如透過傳遞 float_checks 集合來啟用 NAN 和 DIV 錯誤,或者例如透過集合運算 (float_checks | user_checks) 組合多個集合

  • f (Callable[..., Out])

傳回:

一個函式,它接受與 fun 相同的引數,並傳回一個配對作為輸出,其中第一個元素是 Error 值,表示第一個失敗的 check(),第二個元素是 fun 的原始輸出。

傳回類型:

Callable[…, tuple[Error, Out]]

例如

>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>>
>>> @jax.jit
... def f(x):
...   y = jnp.sin(x)
...   return x+y
>>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
>>> err.throw()  
Traceback (most recent call last):
  ...
jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin