jax.profiler.annotate_function#
- jax.profiler.annotate_function(func, name=None, **decorator_kwargs)[原始碼]#
為函式執行產生追蹤事件的裝飾器。
例如
>>> @jax.profiler.annotate_function ... def f(x): ... return jnp.dot(x, x.T).block_until_ready() >>> >>> result = f(jnp.ones((1000, 1000)))
如果在 TensorBoard 追蹤程序的同時執行函式,這將導致在追蹤時間軸上顯示 “f” 事件。
可以透過
functools.partial()
將引數傳遞給裝飾器。>>> from functools import partial
>>> @partial(jax.profiler.annotate_function, name="event_name") ... def f(x): ... return jnp.dot(x, x.T).block_until_ready()
>>> result = f(jnp.ones((1000, 1000)))
- 參數:
func (Callable)
name (str | None | None)