jax.named_call#
- jax.named_call(fun, *, name=None)[原始碼]#
在分段輸出 JAX 計算時,為函式新增使用者指定的名稱。
當為了即時編譯到 XLA (或其他後端,例如 TensorFlow) 而分段輸出計算時,JAX 會執行您的 Python 程式,但預設不會保留與其關聯的任何函式名稱或其他 metadata。這可能會使除錯程式的分段輸出 (和/或編譯) 表示法變得複雜,因為每個正在執行的操作的上下文資訊有限。
named_call 告訴 JAX 將給定的函式分段輸出為具有特定名稱的子計算。當使用 XLA 編譯分段輸出的程式時,這些命名的子計算會被保留,並顯示在除錯工具中,例如 TensorBoard 中的 TensorFlow Profiler。當使用
experimental.jax2tf.convert()
將 JAX 程式分段輸出到 TensorFlow 時,名稱也會被保留。- 參數:
fun (F) – 要包裝的函式。這可以是任何可呼叫物件。
name (str | None | None) – 選用。用於命名在名稱範圍內建立的所有子計算的前綴。如果未指定,則使用 fun.__name__。
- 傳回:
包裝在 name_scope 中的 fun 版本。
- 傳回型別:
F