jax.custom_vjp#
- class jax.custom_vjp(fun, nondiff_argnums=())[原始碼]#
設定可 JAX 轉換的函式,以進行自訂 VJP 規則定義。
此類別旨在作為函式裝飾器使用。實例是可調用的,其行為類似於已套用裝飾器的基礎函式,但當套用反向模式微分轉換(如
jax.grad()
)時除外,在這種情況下,將使用自訂使用者提供的 VJP 規則函式,而不是追蹤到基礎函式的實作中並執行自動微分。有一個單一的實例方法defvjp()
,可用於定義自訂 VJP 規則。此裝飾器排除使用正向模式自動微分。
例如
@jax.custom_vjp def f(x, y): return jnp.sin(x) * y def f_fwd(x, y): return f(x, y), (jnp.cos(x), jnp.sin(x), y) def f_bwd(res, g): cos_x, sin_x, y = res return (cos_x * g * y, sin_x * g) f.defvjp(f_fwd, f_bwd)
如需更詳細的介紹,請參閱教學。
- 參數:
fun (Callable[..., ReturnValue])
nondiff_argnums (Sequence[int])
- __init__(fun, nondiff_argnums=())[原始碼]#
- 參數:
fun (Callable[..., ReturnValue])
nondiff_argnums (Sequence[int])
方法
__init__
(fun[, nondiff_argnums])defvjp
(fwd, bwd[, symbolic_zeros, ...])為此實例代表的函式定義自訂 VJP 規則。