jax.jacfwd#
- jax.jacfwd(fun, argnums=0, has_aux=False, holomorphic=False)[source]#
使用前向模式 AD 逐列計算
fun
的 Jacobian 矩陣。- 參數:
fun (Callable) – 要計算 Jacobian 矩陣的函數。
argnums (int | Sequence[int]) – 選填,整數或整數序列。指定要對哪個位置參數求微分(預設值為
0
)。has_aux (bool) – 選填,布林值。指示
fun
是否傳回一個配對,其中第一個元素被視為要微分的數學函數的輸出,而第二個元素是輔助資料。預設值為 False。holomorphic (bool) – 選填,布林值。指示
fun
是否保證為全純函數。預設值為 False。
- 傳回:
一個與
fun
具有相同引數的函數,使用前向模式自動微分來計算fun
的 Jacobian 矩陣。如果has_aux
為 True,則傳回 (jacobian, auxiliary_data) 配對。- 傳回型別:
Callable
>>> import jax >>> import jax.numpy as jnp >>> >>> def f(x): ... return jnp.asarray( ... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])]) ... >>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.]))) [[ 1. 0. 0. ] [ 0. 0. 5. ] [ 0. 16. -2. ] [ 1.6209 0. 0.84147]]