jax.jvp#

jax.jvp(fun, primals, tangents, has_aux=False)[原始碼]#

計算 fun 的 (前向模式) 雅可比矩陣-向量積。

參數:
  • fun (Callable) – 要微分的函式。其引數應為陣列、純量,或陣列或純量的標準 Python 容器。它應傳回陣列、純量,或陣列或純量的標準 Python 容器。

  • primals – 應評估 fun 的雅可比矩陣的原始值。應為引數的元組或列表,且其長度應等於 fun 的位置參數數量。

  • tangents – 應評估雅可比矩陣-向量積的切線向量。應為切線的元組或列表,其樹狀結構和陣列形狀應與 primals 相同。

  • has_aux (bool) – 選用,布林值。指出 fun 是否傳回一對值,其中第一個元素被視為要微分的數學函式的輸出,而第二個元素是輔助資料。預設值為 False。

傳回值:

如果 has_auxFalse,則傳回 (primals_out, tangents_out) 對,其中 primals_outfun(*primals),而 tangents_out 為在 primals 處評估的 function 的雅可比矩陣-向量積,帶有 tangentstangents_out 值具有與 primals_out 相同的 Python 樹狀結構和形狀。如果 has_auxTrue,則傳回 (primals_out, tangents_out, aux) 元組,其中 auxfun 傳回的輔助資料。

傳回類型:

tuple[Any, …]

例如

>>> import jax
>>>
>>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(primals)
0.09983342
>>> print(tangents)
0.19900084