jax.debug 模組#

執行階段值偵錯工具#

編譯後的列印和中斷點 說明如何使用 JAX 的執行階段值偵錯功能。

callback(callback, *args[, ordered])

呼叫可階段化的 Python 回呼。

print(fmt, *args[, ordered])

列印值並在階段外 JAX 函式中運作。

breakpoint(*[, backend, filter_frames, ...])

在程式中的某個點進入中斷點。

分片偵錯工具#

啟用檢查和視覺化階段函式內(和外)陣列分片的函式。

inspect_array_sharding(value, *, callback)

啟用在 JIT 編譯函式內檢查陣列分片。

visualize_array_sharding(arr, **kwargs)

視覺化陣列的分片。

visualize_sharding(shape, sharding, *[, ...])

使用 rich 視覺化 Sharding