jax.named_scope#

jax.named_scope(name)[source]#

一個上下文管理器,將使用者指定的名稱添加到 JAX 名稱堆疊。

當為了即時編譯到 XLA(或其他後端,例如 TensorFlow)而暫存計算時,JAX 預設不會保留它遇到的 Python 函數的名稱(或其他來源元數據)。這可能會使除錯已暫存(和/或已編譯)的程式表示變得複雜,因為每個正在執行的操作的上下文資訊有限。

named_scope 告訴 JAX 暫存給定的函數,並在底層操作上添加額外的註解。JAX 在內部名稱堆疊中追蹤這些註解。當使用 XLA 編譯暫存的程式時,這些註解會被保留,並顯示在除錯工具中,例如 TensorBoard 中的 TensorFlow Profiler。當使用 experimental.jax2tf.convert() 將 JAX 程式暫存到 TensorFlow 時,名稱也會被保留。

參數:

name (str) – 用於命名在名稱範圍內建立的所有操作的前綴。

Yields:

Yields None,但進入一個上下文,其中 name 將被附加到活動名稱堆疊。

回傳類型:

Generator[None, None, None]

範例

named_scope 可以用作編譯函數內部的上下文管理器

>>> import jax
>>>
>>> @jax.jit
... def layer(w, x):
...   with jax.named_scope("dot_product"):
...     logits = w.dot(x)
...   with jax.named_scope("activation"):
...     return jax.nn.relu(logits)

它也可以用作裝飾器

>>> @jax.jit
... @jax.named_scope("layer")
... def layer(w, x):
...   logits = w.dot(x)
...   return jax.nn.relu(logits)