jax.debug.breakpoint#

jax.debug.breakpoint(*, backend=None, filter_frames=True, num_frames=None, ordered=False, token=None, **kwargs)[原始碼]#

在程式中的某個點進入中斷點。

參數:
  • backend (str | None | None) – 要使用的除錯器後端。預設情況下,選取最高優先順序的除錯器,並且在沒有其他已註冊的除錯器的情況下,回復為 CLI 除錯器。

  • filter_frames (bool) – 是否從追蹤中篩選掉 JAX 內部堆疊框架。由於某些程式庫(例如 Flax)也使用 JAX 的堆疊框架篩選系統,因此此選項也可能會影響是否篩選程式庫中的堆疊框架。

  • num_frames (int | None | None) – 在互動式除錯器中可供檢查的目前堆疊框架之上的框架數。

  • ordered (bool) – 僅限關鍵字引數,用於指示已分段輸出的計算是否將強制執行此 jax.debug.breakpoint 相對於其他已排序的 jax.debug.breakpointjax.debug.print 呼叫的順序。

  • token – 僅限關鍵字引數;ordered 的替代方案。如果使用,則應傳遞 JAX 陣列(或 JAX 陣列的 pytree),並且中斷點將在其值計算完成後執行一次。這將不變地傳回,並且應傳遞回計算。如果傳回值在稍後的計算中未使用,則將修剪整個計算,並且不會執行此中斷點。

傳回值:

如果傳遞 token,則傳回其值不變。否則,傳回 None