jax.debug.callback#

jax.debug.callback(callback, *args, ordered=False, **kwargs)[source]#

呼叫可暫存的 Python 回呼。

更多說明,請參閱外部回呼

jax.debug.callback 讓您傳入一個 Python 函數,該函數可以在暫存的 JAX 程式中呼叫。jax.debug.callback 遵循現有的 JAX 轉換操作語義,因此不會意識到副作用。這表示在存在高階基本運算和轉換的情況下,效果可能會被丟棄、複製或重新排序。

我們希望有這種行為,因為我們希望 jax.debug.callback 是「無害的」,也就是說,我們希望這些基本運算在盡可能少地更改 JAX 計算的同時,盡可能多地揭示有關它們的資訊,例如計算的哪些部分被複製或丟棄。

參數:
  • callback (Callable[..., None]) – 一個返回 None 的 Python 可呼叫物件。

  • *args (Any) – 回呼的位置引數。

  • ordered (bool) – 一個僅限關鍵字的引數,用於指示暫存的計算是否會強制執行此回呼相對於其他排序回呼的順序。

  • **kwargs (Any) – 回呼的關鍵字引數。

返回值:

None

返回類型:

None

另請參閱