jax.lax.stop_gradient#
- jax.lax.stop_gradient(x)[原始碼]#
停止梯度計算。
在操作上,
stop_gradient
是恆等函數,也就是說,它返回未更改的參數 x。然而,stop_gradient
可防止在正向或反向模式自動微分期間的梯度流動。如果有多個巢狀梯度計算,stop_gradient
會停止所有這些計算的梯度。有關此功能在何處有用的更多討論,請參閱 停止梯度。- 參數:
x (T) – 陣列或陣列的 pytree
- 返回:
傳回未更改的輸入值,但在自動微分中將被視為常數。
- 返回型別:
T
範例
考慮一個簡單的函數,它返回輸入值的平方
>>> def f1(x): ... return x ** 2 >>> x = jnp.float32(3.0) >>> f1(x) Array(9.0, dtype=float32) >>> jax.grad(f1)(x) Array(6.0, dtype=float32)
在
x
周圍使用stop_gradient
的相同函數在正常評估下將是等效的,但會返回零梯度,因為x
實際上被視為常數>>> def f2(x): ... return jax.lax.stop_gradient(x) ** 2 >>> f2(x) Array(9.0, dtype=float32) >>> jax.grad(f2)(x) Array(0.0, dtype=float32)
這在 JAX 程式碼庫中的許多地方都有使用;例如,
jax.nn.softmax()
在內部通過其最大值來正規化輸入,並且此最大值被包裝在stop_gradient
中以提高效率。有關stop_gradient
適用性的更多討論,請參閱 停止梯度。