jax.value_and_grad#

jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[原始碼]#

建立一個函式,同時評估 funfun 的梯度。

參數:
  • fun (Callable) – 要微分的函式。其在 argnums 指定位置的引數應為陣列、純量或標準 Python 容器。它應傳回純量 (包括形狀為 () 的陣列,但不包括形狀為 (1,) 等的陣列)。

  • argnums (int | Sequence[int]) – 選用,整數或整數序列。指定要針對哪個位置引數進行微分 (預設值為 0)。

  • has_aux (bool) – 選用,布林值。指出 fun 是否傳回一對值,其中第一個元素被視為要微分的數學函式的輸出,第二個元素是輔助資料。預設值為 False。

  • holomorphic (bool) – 選用,布林值。指出 fun 是否保證為全純函數。如果為 True,則輸入和輸出必須為複數。預設值為 False。

  • allow_int (bool) – 選用,布林值。是否允許針對整數值輸入進行微分。整數輸入的梯度將具有微不足道的向量空間 dtype (float0)。預設值為 False。

  • reduce_axes (Sequence[AxisName])

傳回值:

一個與 fun 具有相同引數的函式,它同時評估 funfun 的梯度,並將它們作為一對值傳回 (一個雙元素元組)。如果 argnums 是一個整數,則梯度具有與該整數指示的位置引數相同的形狀和類型。如果 argnums 是一個整數序列,則梯度是一個值元組,其形狀和類型與相應的引數相同。如果 has_aux 為 True,則傳回 ((值, 輔助資料), 梯度) 的元組。

傳回類型:

Callable[…, tuple[Any, Any]]