jax.custom_vjp#

class jax.custom_vjp(fun, nondiff_argnums=())[原始碼]#

設定可 JAX 轉換的函式,以進行自訂 VJP 規則定義。

此類別旨在作為函式裝飾器使用。實例是可調用的,其行為類似於已套用裝飾器的基礎函式,但當套用反向模式微分轉換(如 jax.grad())時除外,在這種情況下,將使用自訂使用者提供的 VJP 規則函式,而不是追蹤到基礎函式的實作中並執行自動微分。有一個單一的實例方法 defvjp(),可用於定義自訂 VJP 規則。

此裝飾器排除使用正向模式自動微分。

例如

@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

如需更詳細的介紹,請參閱教學

參數:
  • fun (Callable[..., ReturnValue])

  • nondiff_argnums (Sequence[int])

__init__(fun, nondiff_argnums=())[原始碼]#
參數:
  • fun (Callable[..., ReturnValue])

  • nondiff_argnums (Sequence[int])

方法

__init__(fun[, nondiff_argnums])

defvjp(fwd, bwd[, symbolic_zeros, ...])

為此實例代表的函式定義自訂 VJP 規則。