jax.numpy.gradient#
- jax.numpy.gradient(f, *varargs, axis=None, edge_order=None)[原始碼]#
計算取樣函數的數值梯度。
JAX 實作的
numpy.gradient()
。jnp.gradient
中的梯度是使用二階有限差分在取樣函數值的陣列上計算的。這不應與jax.grad()
混淆,後者透過 自動微分 計算可呼叫函數的精確梯度。- 參數:
- 傳回:
包含沿每個指定軸的數值梯度的陣列或陣列元組。
- 傳回類型:
參見
jax.grad()
:具有單一輸出的函數的自動微分。
範例
比較簡單函數的數值微分和自動微分
>>> def f(x): ... return jnp.sin(x) * jnp.exp(-x / 4) ... >>> def gradf_exact(x): ... # exact analytical gradient of f(x) ... return -f(x) / 4 + jnp.cos(x) * jnp.exp(-x / 4) ... >>> x = jnp.linspace(0, 5, 10)
>>> with jnp.printoptions(precision=2, suppress=True): ... print("numerical gradient:", jnp.gradient(f(x), x)) ... print("automatic gradient:", jax.vmap(jax.grad(f))(x)) ... print("exact gradient: ", gradf_exact(x)) ... numerical gradient: [ 0.83 0.61 0.18 -0.2 -0.43 -0.49 -0.39 -0.21 -0.02 0.08] automatic gradient: [ 1. 0.62 0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01 0.15] exact gradient: [ 1. 0.62 0.17 -0.23 -0.46 -0.51 -0.41 -0.21 -0.01 0.15]
請注意,如預期,與透過
jax.grad()
計算的自動梯度相比,數值梯度存在一些近似誤差。