jax.linearize#

jax.linearize(fun: Callable, *primals, has_aux: Literal[False] = False) tuple[Any, Callable][原始碼]#
jax.linearize(fun: Callable, *primals, has_aux: Literal[True]) tuple[Any, Callable, Any]

使用 jvp() 和部分求值,產生 fun 的線性近似。

參數:
  • fun – 要微分的函數。其引數應為陣列、純量,或陣列或純量的標準 Python 容器。它應該傳回陣列、純量,或陣列或純量的標準 Python 容器。

  • primals – 應該在其中評估 fun 的 Jacobian 的原始值。應為陣列、純量或其標準 Python 容器的元組。fun 的位置參數數量應等於元組的長度。

  • has_aux – 選擇性,布林值。指出 fun 是否傳回一對值,其中第一個元素被視為要線性化的數學函數的輸出,第二個元素是輔助資料。預設為 False。

傳回值:

如果 has_auxFalse,則傳回一對值,其中第一個元素是 f(*primals) 的值,第二個元素是一個函數,該函數評估在 primals 評估的 fun 的(前向模式)Jacobian-向量乘積,而無需重新進行線性化工作。如果 has_auxTrue,則傳回 (primals_out, lin_fn, aux) 元組,其中 auxfun 傳回的輔助資料。

在計算值方面,linearize() 的行為非常類似於 curried jvp(),其中以下兩個程式碼區塊計算相同的值

y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)

然而,不同之處在於 linearize() 使用部分求值,因此函數 f 不會在呼叫 f_jvp 時重新線性化。一般而言,這表示記憶體使用量會隨著計算大小而擴展,非常像反向模式。(實際上,linearize() 具有與 vjp() 相似的簽名!)

如果您想要多次套用 f_jvp,即在相同的線性化點評估許多不同輸入切向量的 pushforward,則此函數主要很有用。此外,如果所有輸入切向量都一次已知,則使用 vmap() 進行向量化可能更有效率,如下所示

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

透過像這樣一起使用 vmap()jvp(),我們可以避免儲存的線性化記憶體成本,該成本會隨著計算深度而擴展,而 linearize()vjp() 都會產生此成本。

以下是使用 linearize() 的更完整範例

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704