jax.grad#

jax.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[原始碼]#

建立一個函式,用於評估 fun 的梯度。

參數:
  • fun (Callable) – 要微分的函式。由 argnums 指定位置的引數應為陣列、純量或標準 Python 容器。由 argnums 指定位置的引數陣列必須為非精確 (即,浮點或複數) 型別。它應傳回純量 (包含形狀為 () 的陣列,但不包含形狀為 (1,) 等的陣列)

  • argnums (int | Sequence[int]) – 選用,整數或整數序列。指定要針對哪個位置引數進行微分 (預設為 0)。

  • has_aux (bool) – 選用,布林值。指示 fun 是否傳回一對值,其中第一個元素被視為要微分的數學函式的輸出,而第二個元素是輔助資料。預設為 False。

  • holomorphic (bool) – 選用,布林值。指示 fun 是否保證為全純函數。如果為 True,則輸入和輸出必須為複數。預設為 False。

  • allow_int (bool) – 選用,布林值。是否允許針對整數值輸入進行微分。整數輸入的梯度將具有微不足道的向量空間 dtype (float0)。預設為 False。

  • reduce_axes (Sequence[AxisName])

傳回:

fun 具有相同引數的函式,用於評估 fun 的梯度。如果 argnums 是整數,則梯度具有與該整數指示的位置引數相同的形狀和型別。如果 argnums 是整數元組,則梯度是值元組,其形狀和型別與對應的引數相同。如果 has_aux 為 True,則傳回 (梯度,輔助資料) 對。

傳回型別:

Callable

例如

>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043