jax.experimental.checkify.check#

jax.experimental.checkify.check(pred, msg, *fmt_args, debug=False, **fmt_kwargs)[原始碼]#

檢查述詞,如果述詞為 False,則新增帶有 msg 的錯誤。

這是一個有副作用的操作,無法被 staged (jitted/scanned/…)。在 staging 具有檢查的函式之前,請先 checkify() 它!

參數:
  • pred (Bool) – 如果為 False,則會新增 FailedCheckError 錯誤。

  • msg (str) – 如果新增錯誤時的錯誤訊息。可以是格式字串。

  • debug (bool) – 是否開啟偵錯模式。如果為 True,則在執行期間將移除檢查。如果為 False,則必須使用 checkify.checkify 將檢查功能化。

  • fmt_argsmsg 的位置和關鍵字格式化引數,例如:check(.., "check failed on values {} and {named_arg}", x, named_arg=y) 請注意,這些引數可以是追蹤值,允許您將執行時期值新增至錯誤訊息。請注意,即使沒有發生錯誤,追蹤這些執行時期陣列也會增加您的記憶體使用量。

  • fmt_kwargsmsg 的位置和關鍵字格式化引數,例如:check(.., "check failed on values {} and {named_arg}", x, named_arg=y) 請注意,這些引數可以是追蹤值,允許您將執行時期值新增至錯誤訊息。請注意,即使沒有發生錯誤,追蹤這些執行時期陣列也會增加您的記憶體使用量。

傳回類型:

None

例如

>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>> def f(x):
...   checkify.check(x>0, "{x} needs to be positive!", x=x)
...   return 1/x
>>> checked_f = checkify.checkify(f)
>>> err, out = jax.jit(checked_f)(-3.)
>>> err.throw()  
Traceback (most recent call last):
  ...
jax._src.checkify.JaxRuntimeError: -3. needs to be positive!