jax.jvp#
- jax.jvp(fun, primals, tangents, has_aux=False)[原始碼]#
計算
fun
的 (前向模式) 雅可比矩陣-向量積。- 參數:
fun (Callable) – 要微分的函式。其引數應為陣列、純量,或陣列或純量的標準 Python 容器。它應傳回陣列、純量,或陣列或純量的標準 Python 容器。
primals – 應評估
fun
的雅可比矩陣的原始值。應為引數的元組或列表,且其長度應等於fun
的位置參數數量。tangents – 應評估雅可比矩陣-向量積的切線向量。應為切線的元組或列表,其樹狀結構和陣列形狀應與
primals
相同。has_aux (bool) – 選用,布林值。指出
fun
是否傳回一對值,其中第一個元素被視為要微分的數學函式的輸出,而第二個元素是輔助資料。預設值為 False。
- 傳回值:
如果
has_aux
為False
,則傳回(primals_out, tangents_out)
對,其中primals_out
為fun(*primals)
,而tangents_out
為在primals
處評估的function
的雅可比矩陣-向量積,帶有tangents
。tangents_out
值具有與primals_out
相同的 Python 樹狀結構和形狀。如果has_aux
為True
,則傳回(primals_out, tangents_out, aux)
元組,其中aux
為fun
傳回的輔助資料。- 傳回類型:
tuple[Any, …]
例如
>>> import jax >>> >>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,)) >>> print(primals) 0.09983342 >>> print(tangents) 0.19900084