jax.debug.print#

jax.debug.print(fmt, *args, ordered=False, **kwargs)[原始碼]#

印出值並在 staged out JAX 函數中運作。

此函數適用於 f-strings,因為格式化會延遲。因此,請寫成 jax.debug.print("hello {bar}", bar=bar),而不是 jax.debug.print(f"hello {bar}")

此函數是 jax.debug.callback() 的精簡便利包裝函式。實作基本上是

def debug_print(fmt: str, *args, **kwargs):
  jax.debug.callback(
      lambda *args, **kwargs: print(fmt.format(*args, **kwargs)),
      *args, **kwargs)

直接呼叫 jax.debug.callback() 而不是這個便利包裝函式可能很有用。例如,若要在記錄中取得除錯列印,您可以將 jax.debug.callback()logging.log 一起使用。

參數:
  • fmt (str) – 格式字串,例如 "hello {x}",將用於格式化輸入引數,如 str.format。請參閱 Python 文件中的 字串格式化格式字串語法

  • *args – 要格式化的位置引數列表,如同傳遞至 fmt.format

  • ordered (bool) – 僅限關鍵字引數,用於指示 staged out 計算是否會強制執行此 jax.debug.print 相對於其他 ordered jax.debug.print 呼叫的順序。

  • **kwargs – 要格式化的其他關鍵字引數,如同傳遞至 fmt.format

返回類型:

None