jax.named_scope#
- jax.named_scope(name)[source]#
一個上下文管理器,將使用者指定的名稱添加到 JAX 名稱堆疊。
當為了即時編譯到 XLA(或其他後端,例如 TensorFlow)而暫存計算時,JAX 預設不會保留它遇到的 Python 函數的名稱(或其他來源元數據)。這可能會使除錯已暫存(和/或已編譯)的程式表示變得複雜,因為每個正在執行的操作的上下文資訊有限。
named_scope
告訴 JAX 暫存給定的函數,並在底層操作上添加額外的註解。JAX 在內部名稱堆疊中追蹤這些註解。當使用 XLA 編譯暫存的程式時,這些註解會被保留,並顯示在除錯工具中,例如 TensorBoard 中的 TensorFlow Profiler。當使用experimental.jax2tf.convert()
將 JAX 程式暫存到 TensorFlow 時,名稱也會被保留。- 參數:
name (str) – 用於命名在名稱範圍內建立的所有操作的前綴。
- Yields:
Yields
None
,但進入一個上下文,其中 name 將被附加到活動名稱堆疊。- 回傳類型:
Generator[None, None, None]
範例
named_scope
可以用作編譯函數內部的上下文管理器>>> import jax >>> >>> @jax.jit ... def layer(w, x): ... with jax.named_scope("dot_product"): ... logits = w.dot(x) ... with jax.named_scope("activation"): ... return jax.nn.relu(logits)
它也可以用作裝飾器
>>> @jax.jit ... @jax.named_scope("layer") ... def layer(w, x): ... logits = w.dot(x) ... return jax.nn.relu(logits)