jax.custom_jvp.defjvp#

custom_jvp.defjvp(jvp, symbolic_zeros=False)[原始碼]#

為此實例代表的函式定義自訂 JVP 規則。

參數:
  • jvp (Callable[..., tuple[ReturnValue, ReturnValue]]) – 代表自訂 JVP 規則的 Python 可呼叫物件。當沒有 nondiff_argnums 時,jvp 函式應接受兩個引數,其中第一個是原始輸入的元組,第二個是切線輸入的元組。兩個元組的長度都等於 custom_jvp 函式的參數數量。jvp 函式應產生一個配對作為輸出,其中第一個元素是原始輸出,第二個元素是切線輸出。輸入和輸出元組的元素可以是陣列或其任何巢狀元組/列表/字典。

  • symbolic_zeros (bool) – 布林值,指示是否應將代表靜態符號零的物件傳遞到其切線引數中,以對應未受擾動的值;否則,只會傳遞標準 JAX 類型 (例如類陣列)。將此選項設定為 True 可讓 JVP 規則偵測某些輸入是否未參與微分,但代價是需要對這些物件進行特殊處理 (例如,這些物件無法傳遞到 jax.numpy 函式中)。預設值為 False

傳回:

傳回 jvp,以便 defjvp 可以用作裝飾器。

傳回類型:

Callable[…, tuple[ReturnValue, ReturnValue]]

範例

>>> @jax.custom_jvp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> @f.defjvp
... def f_jvp(primals, tangents):
...   x, y = primals
...   x_dot, y_dot = tangents
...   primal_out = f(x, y)
...   tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
...   return primal_out, tangent_out
>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))