jax.custom_jvp#

class jax.custom_jvp(fun, nondiff_argnums=())[原始碼]#

設定一個 JAX 可轉換函式,用於自訂 JVP 規則定義。

此類別旨在作為函式裝飾器使用。實例是可呼叫的,其行為類似於裝飾器應用於的底層函式,但當應用微分轉換(如 jax.jvp()jax.grad())時,會使用自訂使用者提供的 JVP 規則函式,而不是追蹤到底層函式的實作並執行自動微分。

有兩種實例方法可用於定義自訂 JVP 規則:defjvp() 用於為函式的所有輸入定義單一自訂 JVP 規則,為了方便起見,defjvps() 包裝了 defjvp(),並允許您為函式相對於其每個引數的偏導數提供單獨的定義。

例如

@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out

如需更詳細的介紹,請參閱教學

參數:
  • fun (Callable[..., ReturnValue])

  • nondiff_argnums (Sequence[int])

__init__(fun, nondiff_argnums=())[原始碼]#
參數:
  • fun (Callable[..., ReturnValue])

  • nondiff_argnums (Sequence[int])

方法

__init__(fun[, nondiff_argnums])

defjvp(jvp[, symbolic_zeros])

為此實例表示的函式定義自訂 JVP 規則。

defjvps(*jvps)

用於為每個引數分別定義 JVP 的便利包裝器。

屬性

jvp

symbolic_zeros

fun

nondiff_argnums