jax.jacrev#
- jax.jacrev(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)[原始碼]#
使用反向模式 AD 逐列評估
fun
的 Jacobian 矩陣。- 參數:
fun (Callable) – 要計算 Jacobian 矩陣的函式。
argnums (int | Sequence[int]) – 選用,整數或整數序列。指定要針對哪個位置引數進行微分(預設為
0
)。has_aux (bool) – 選用,布林值。指示
fun
是否傳回一個配對,其中第一個元素被視為要微分的數學函式的輸出,而第二個元素是輔助資料。預設為 False。holomorphic (bool) – 選用,布林值。指示
fun
是否保證為全純函式。預設為 False。allow_int (bool) – 選用,布林值。是否允許針對整數值輸入進行微分。整數輸入的梯度將具有微不足道的向量空間 dtype (float0)。預設為 False。
- 傳回值:
與
fun
具有相同引數的函式,該函式使用反向模式自動微分來評估fun
的 Jacobian 矩陣。如果has_aux
為 True,則傳回 (jacobian, auxiliary_data) 配對。- 傳回類型:
Callable
>>> import jax >>> import jax.numpy as jnp >>> >>> def f(x): ... return jnp.asarray( ... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])]) ... >>> print(jax.jacrev(f)(jnp.array([1., 2., 3.]))) [[ 1. 0. 0. ] [ 0. 0. 5. ] [ 0. 16. -2. ] [ 1.6209 0. 0.84147]]