jax.custom_vjp.defvjp#

custom_vjp.defvjp(fwd, bwd, symbolic_zeros=False, optimize_remat=False)[原始碼]#

為此實例表示的函式定義自訂 VJP 規則。

參數:
  • fwd (Callable[..., tuple[ReturnValue, Any]]) – 代表自訂 VJP 規則的前向傳遞的 Python 可呼叫物件。當沒有 nondiff_argnums 時,fwd 函式具有與底層原始函式相同的輸入簽章。它應傳回一個配對作為輸出,其中第一個元素代表原始輸出,第二個元素代表要從前向傳遞中儲存的任何「殘差」值,以供函式 bwd 在反向傳遞中使用。輸入引數和輸出配對的元素可以是陣列或巢狀 tuple/list/dict。

  • bwd (Callable[..., tuple[Any, ...]]) – 代表自訂 VJP 規則的反向傳遞的 Python 可呼叫物件。當沒有 nondiff_argnums 時,bwd 函式接受兩個引數,其中第一個引數是由 fwd 在前向傳遞中產生的「殘差」值,第二個引數是與原始函式輸出具有相同結構的輸出餘切。 bwd 的輸出必須是一個 tuple,其長度等於原始函式的引數數量,並且 tuple 元素可以是陣列或巢狀 tuple/list/dict,以便與原始輸入引數的結構匹配。

  • symbolic_zeros (bool) –

    布林值,決定是否向 fwdbwd 規則指示符號零。啟用此選項允許自訂導數規則偵測某些輸入以及某些輸出餘切何時未參與微分。如果 True

    • fwd 必須接受一個物件 (型別為 jax.custom_derivatives.CustomVJPPrimal),而不是構成原始函式引數的 pytree 中的每個葉值 x,該物件具有兩個屬性:valueperturbedvalue 欄位是原始原始引數,而 perturbed 是一個布林值。perturbed 位元指示引數是否參與微分 (即,如果它是 False,則對應的 Jacobian「列」為零)。

    • bwd 將傳遞物件,這些物件在其餘切引數中代表與未擾動值對應的靜態符號零;否則,僅傳遞標準 JAX 型別 (例如,類陣列)。

    將此選項設定為 True 允許這些規則偵測某些輸入和輸出是否未參與微分,但代價是需要特殊處理。例如

    • fwd 的簽章會變更,並且傳遞給它的物件不能直接從規則輸出。

    • bwd 規則會傳遞並非完全類陣列的物件,且這些物件無法傳遞給大多數 jax.numpy 函式。

    • 原始函式的引數中涉及的任何自訂 pytree 節點都必須在其 unflattening 函式中接受作為 fwd 規則的輸入葉給定的雙欄位記錄物件。

    預設值 False

  • optimize_remat (bool) – 布林值,一個實驗性旗標,用於在 jax.remat() 下使用此函式時啟用自動最佳化。當 fwd 規則是不透明呼叫 (例如 Pallas 核心或自訂呼叫) 時,這將最有用。預設值 False

傳回:

無。

傳回型別:

範例

>>> @jax.custom_vjp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> def f_fwd(x, y):
...   return f(x, y), (jnp.cos(x), jnp.sin(x), y)
...
>>> def f_bwd(res, g):
...   cos_x, sin_x, y = res
...   return (cos_x * g * y, sin_x * g)
...
>>> f.defvjp(f_fwd, f_bwd)
>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))