jax.jacfwd#

jax.jacfwd(fun, argnums=0, has_aux=False, holomorphic=False)[source]#

使用前向模式 AD 逐列計算 fun 的 Jacobian 矩陣。

參數:
  • fun (Callable) – 要計算 Jacobian 矩陣的函數。

  • argnums (int | Sequence[int]) – 選填,整數或整數序列。指定要對哪個位置參數求微分(預設值為 0)。

  • has_aux (bool) – 選填,布林值。指示 fun 是否傳回一個配對,其中第一個元素被視為要微分的數學函數的輸出,而第二個元素是輔助資料。預設值為 False。

  • holomorphic (bool) – 選填,布林值。指示 fun 是否保證為全純函數。預設值為 False。

傳回:

一個與 fun 具有相同引數的函數,使用前向模式自動微分來計算 fun 的 Jacobian 矩陣。如果 has_aux 為 True,則傳回 (jacobian, auxiliary_data) 配對。

傳回型別:

Callable

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]