jax.custom_jvp.defjvp#
- custom_jvp.defjvp(jvp, symbolic_zeros=False)[原始碼]#
為此實例代表的函式定義自訂 JVP 規則。
- 參數:
jvp (Callable[..., tuple[ReturnValue, ReturnValue]]) – 代表自訂 JVP 規則的 Python 可呼叫物件。當沒有
nondiff_argnums
時,jvp
函式應接受兩個引數,其中第一個是原始輸入的元組,第二個是切線輸入的元組。兩個元組的長度都等於custom_jvp
函式的參數數量。jvp
函式應產生一個配對作為輸出,其中第一個元素是原始輸出,第二個元素是切線輸出。輸入和輸出元組的元素可以是陣列或其任何巢狀元組/列表/字典。symbolic_zeros (bool) – 布林值,指示是否應將代表靜態符號零的物件傳遞到其切線引數中,以對應未受擾動的值;否則,只會傳遞標準 JAX 類型 (例如類陣列)。將此選項設定為
True
可讓 JVP 規則偵測某些輸入是否未參與微分,但代價是需要對這些物件進行特殊處理 (例如,這些物件無法傳遞到 jax.numpy 函式中)。預設值為False
。
- 傳回:
傳回
jvp
,以便defjvp
可以用作裝飾器。- 傳回類型:
Callable[…, tuple[ReturnValue, ReturnValue]]
範例
>>> @jax.custom_jvp ... def f(x, y): ... return jnp.sin(x) * y ... >>> @f.defjvp ... def f_jvp(primals, tangents): ... x, y = primals ... x_dot, y_dot = tangents ... primal_out = f(x, y) ... tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot ... return primal_out, tangent_out
>>> x = jnp.float32(1.0) >>> y = jnp.float32(2.0) >>> with jnp.printoptions(precision=2): ... print(jax.value_and_grad(f)(x, y)) (Array(1.68, dtype=float32), Array(1.08, dtype=float32))