jax.lax.stop_gradient#

jax.lax.stop_gradient(x)[原始碼]#

停止梯度計算。

在操作上,stop_gradient 是恆等函數,也就是說,它返回未更改的參數 x。然而,stop_gradient 可防止在正向或反向模式自動微分期間的梯度流動。如果有多個巢狀梯度計算,stop_gradient 會停止所有這些計算的梯度。有關此功能在何處有用的更多討論,請參閱 停止梯度

參數:

x (T) – 陣列或陣列的 pytree

返回:

傳回未更改的輸入值,但在自動微分中將被視為常數。

返回型別:

T

範例

考慮一個簡單的函數,它返回輸入值的平方

>>> def f1(x):
...   return x ** 2
>>> x = jnp.float32(3.0)
>>> f1(x)
Array(9.0, dtype=float32)
>>> jax.grad(f1)(x)
Array(6.0, dtype=float32)

x 周圍使用 stop_gradient 的相同函數在正常評估下將是等效的,但會返回零梯度,因為 x 實際上被視為常數

>>> def f2(x):
...   return jax.lax.stop_gradient(x) ** 2
>>> f2(x)
Array(9.0, dtype=float32)
>>> jax.grad(f2)(x)
Array(0.0, dtype=float32)

這在 JAX 程式碼庫中的許多地方都有使用;例如,jax.nn.softmax() 在內部通過其最大值來正規化輸入,並且此最大值被包裝在 stop_gradient 中以提高效率。有關 stop_gradient 適用性的更多討論,請參閱 停止梯度