jax.custom_jvp#
- class jax.custom_jvp(fun, nondiff_argnums=())[原始碼]#
設定一個 JAX 可轉換函式,用於自訂 JVP 規則定義。
此類別旨在作為函式裝飾器使用。實例是可呼叫的,其行為類似於裝飾器應用於的底層函式,但當應用微分轉換(如
jax.jvp()
或jax.grad()
)時,會使用自訂使用者提供的 JVP 規則函式,而不是追蹤到底層函式的實作並執行自動微分。有兩種實例方法可用於定義自訂 JVP 規則:
defjvp()
用於為函式的所有輸入定義單一自訂 JVP 規則,為了方便起見,defjvps()
包裝了defjvp()
,並允許您為函式相對於其每個引數的偏導數提供單獨的定義。例如
@jax.custom_jvp def f(x, y): return jnp.sin(x) * y @f.defjvp def f_jvp(primals, tangents): x, y = primals x_dot, y_dot = tangents primal_out = f(x, y) tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot return primal_out, tangent_out
如需更詳細的介紹,請參閱教學。
- 參數:
fun (Callable[..., ReturnValue])
nondiff_argnums (Sequence[int])
- __init__(fun, nondiff_argnums=())[原始碼]#
- 參數:
fun (Callable[..., ReturnValue])
nondiff_argnums (Sequence[int])
方法
__init__
(fun[, nondiff_argnums])defjvp
(jvp[, symbolic_zeros])為此實例表示的函式定義自訂 JVP 規則。
defjvps
(*jvps)用於為每個引數分別定義 JVP 的便利包裝器。
屬性
jvp
symbolic_zeros
fun
nondiff_argnums