jax.debug.print#
- jax.debug.print(fmt, *args, ordered=False, **kwargs)[原始碼]#
印出值並在 staged out JAX 函數中運作。
此函數不適用於 f-strings,因為格式化會延遲。因此,請寫成
jax.debug.print("hello {bar}", bar=bar)
,而不是jax.debug.print(f"hello {bar}")
。此函數是
jax.debug.callback()
的精簡便利包裝函式。實作基本上是def debug_print(fmt: str, *args, **kwargs): jax.debug.callback( lambda *args, **kwargs: print(fmt.format(*args, **kwargs)), *args, **kwargs)
直接呼叫
jax.debug.callback()
而不是這個便利包裝函式可能很有用。例如,若要在記錄中取得除錯列印,您可以將jax.debug.callback()
與logging.log
一起使用。