jax.experimental.checkify.checkify#
- jax.experimental.checkify.checkify(f, errors=frozenset({<class 'jax._src.checkify.FailedCheckError'>}))[原始碼]#
將 fun 中的 check 呼叫函數化,並選擇性地新增執行階段錯誤檢查。
執行階段錯誤可以是使用者新增的
check()
斷言,或自動新增的檢查(例如 NaN 檢查),取決於errors
引數。傳回的函式將傳回 Error 物件 err 以及原始函式的輸出。
err.get()
將傳回None
(如果沒有發生錯誤)或包含錯誤訊息的字串。此錯誤訊息將對應於發生的第一個錯誤。err.throw()
將在發生錯誤時引發 ValueError 並顯示錯誤訊息。預設情況下,僅啟用使用者新增的
check()
斷言。您可以透過errors
引數啟用自動檢查。- 可以啟用的自動檢查集合,以及何時產生錯誤
user_checks
:check()
評估為 False。nan_checks
:浮點運算產生 NaN 值作為輸出。div_checks
:除以零。index_checks
:索引超出範圍。
可以透過傳入錯誤 Set (例如
errors=nan_checks
) 一起啟用多個類別。可以重新組合多個集合 (例如errors=float_checks|user_checks
)- 參數:
fun – 可包含使用者檢查的 Callable (請參閱
check()
)。errors (frozenset[ErrorCategory]) – ErrorCategory 值的集合,用於定義已啟用的檢查集合。預設情況下,僅啟用顯式
checks
(user_checks
)。您也可以例如透過傳遞float_checks
集合來啟用 NAN 和 DIV 錯誤,或者例如透過集合運算 (float_checks | user_checks
) 組合多個集合f (Callable[..., Out])
- 傳回:
一個函式,它接受與
fun
相同的引數,並傳回一個配對作為輸出,其中第一個元素是Error
值,表示第一個失敗的check()
,第二個元素是fun
的原始輸出。- 傳回類型:
例如
>>> import jax >>> import jax.numpy as jnp >>> from jax.experimental import checkify >>> >>> @jax.jit ... def f(x): ... y = jnp.sin(x) ... return x+y >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf) >>> err.throw() Traceback (most recent call last): ... jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin