jax.custom_gradient#

jax.custom_gradient(fun)[原始碼]#

用於定義自訂 VJP 規則 (又名自訂梯度) 的便利函數。

雖然定義自訂 VJP 規則的標準方式是透過 jax.custom_vjp,但 custom_gradient 便利包裝函式遵循 TensorFlow 的 tf.custom_gradient API。此處的不同之處在於 custom_gradient 可以用作裝飾器,裝飾在一個函數上,該函數會同時傳回原始值 (表示要微分的數學函數的輸出) 和 VJP (梯度) 函數。請參閱 https://tensorflow.dev.org.tw/api_docs/python/tf/custom_gradient

如果要微分的數學函數具有類似 Haskell 的簽名 a -> b,則 Python 可呼叫物件 fun 應具有簽名 a -> (b, CT b --o CT a),其中我們使用 CT x 表示 x 的餘切類型,並使用 --o 箭頭表示線性函數。請參閱以下範例。也就是說,fun 應傳回一個配對,其中第一個元素表示要微分的數學函數的值,第二個元素是一個函數,在反向模式自動微分的反向傳遞 (即「自訂梯度」函數) 中呼叫。

作為 fun 輸出第二個元素傳回的函數可以封閉 (close over) 在評估要微分的函數時計算的中間值。也就是說,使用詞法閉包在反向模式自動微分的前向傳遞和反向傳遞之間共享工作。但是,它不能執行依賴於封閉的中間值或其餘切引數值的 Python 控制流程;如果函數包含此類控制流程,則會引發錯誤。

參數:

fun – 一個 Python 可呼叫物件,指定要微分的數學函數及其反向模式微分規則。它應傳回一個配對,其中包含一個輸出值和一個代表自訂梯度函數的 Python 可呼叫物件。

傳回:

一個 Python 可呼叫物件,它接受與 fun 相同的引數,並傳回由 fun 的輸出配對的第一個元素指定的輸出值。

例如

>>> @jax.custom_gradient
... def f(x):
...   return x ** 2, lambda g: (g * x,)
...
>>> print(f(3.))
9.0
>>> print(jax.grad(f)(3.))
3.0

一個具有兩個引數的函數的範例,因此 VJP 函數必須傳回長度為二的元組

>>> @jax.custom_gradient
... def f(x, y):
...   return x * y, lambda g: (g * y, g * x)
...
>>> print(f(3., 4.))
12.0
>>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
(Array(4., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True))