jax.numpy.gradient#

jax.numpy.gradient(f, *varargs, axis=None, edge_order=None)[原始碼]#

計算取樣函數的數值梯度。

JAX 實作的 numpy.gradient()

jnp.gradient 中的梯度是使用二階有限差分在取樣函數值的陣列上計算的。這不應與 jax.grad() 混淆,後者透過 自動微分 計算可呼叫函數的精確梯度。

參數:
  • f (ArrayLike) – 函數值的N維陣列。

  • varargs (ArrayLike) –

    指定函數評估間隔的可選純量或陣列列表。選項為

    • 未指定:所有維度中的單位間隔。

    • 單一純量:所有維度中的常數間隔。

    • N 個值:指定每個維度中的不同間隔

      • 純量值表示該維度中的常數間隔。

      • 陣列值必須符合對應維度的長度,並指定評估 f 的座標。

  • edge_order (int | None) – 在 JAX 中未實作

  • axis (int | Sequence[int] | None) – 指定要沿其計算梯度的軸的整數或整數元組。如果為 None (預設值),則計算沿所有軸的梯度。

傳回:

包含沿每個指定軸的數值梯度的陣列或陣列元組。

傳回類型:

Array | list[Array]

參見

  • jax.grad():具有單一輸出的函數的自動微分。

範例

比較簡單函數的數值微分和自動微分

>>> def f(x):
...   return jnp.sin(x) * jnp.exp(-x / 4)
...
>>> def gradf_exact(x):
...   # exact analytical gradient of f(x)
...   return -f(x) / 4 + jnp.cos(x) * jnp.exp(-x / 4)
...
>>> x = jnp.linspace(0, 5, 10)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print("numerical gradient:", jnp.gradient(f(x), x))
...   print("automatic gradient:", jax.vmap(jax.grad(f))(x))
...   print("exact gradient:    ", gradf_exact(x))
...
numerical gradient: [ 0.83  0.61  0.18 -0.2  -0.43 -0.49 -0.39 -0.21 -0.02  0.08]
automatic gradient: [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]
exact gradient:     [ 1.    0.62  0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01  0.15]

請注意,如預期,與透過 jax.grad() 計算的自動梯度相比,數值梯度存在一些近似誤差。