jax.linearize#
- jax.linearize(fun: Callable, *primals, has_aux: Literal[False] = False) tuple[Any, Callable] [原始碼]#
- jax.linearize(fun: Callable, *primals, has_aux: Literal[True]) tuple[Any, Callable, Any]
使用
jvp()
和部分求值,產生fun
的線性近似。- 參數:
fun – 要微分的函數。其引數應為陣列、純量,或陣列或純量的標準 Python 容器。它應該傳回陣列、純量,或陣列或純量的標準 Python 容器。
primals – 應該在其中評估
fun
的 Jacobian 的原始值。應為陣列、純量或其標準 Python 容器的元組。fun
的位置參數數量應等於元組的長度。has_aux – 選擇性,布林值。指出
fun
是否傳回一對值,其中第一個元素被視為要線性化的數學函數的輸出,第二個元素是輔助資料。預設為 False。
- 傳回值:
如果
has_aux
為False
,則傳回一對值,其中第一個元素是f(*primals)
的值,第二個元素是一個函數,該函數評估在primals
評估的fun
的(前向模式)Jacobian-向量乘積,而無需重新進行線性化工作。如果has_aux
為True
,則傳回(primals_out, lin_fn, aux)
元組,其中aux
是fun
傳回的輔助資料。
在計算值方面,
linearize()
的行為非常類似於 curriedjvp()
,其中以下兩個程式碼區塊計算相同的值y, out_tangent = jax.jvp(f, (x,), (in_tangent,)) y, f_jvp = jax.linearize(f, x) out_tangent = f_jvp(in_tangent)
然而,不同之處在於
linearize()
使用部分求值,因此函數f
不會在呼叫f_jvp
時重新線性化。一般而言,這表示記憶體使用量會隨著計算大小而擴展,非常像反向模式。(實際上,linearize()
具有與vjp()
相似的簽名!)如果您想要多次套用
f_jvp
,即在相同的線性化點評估許多不同輸入切向量的 pushforward,則此函數主要很有用。此外,如果所有輸入切向量都一次已知,則使用vmap()
進行向量化可能更有效率,如下所示pushfwd = partial(jvp, f, (x,)) y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
透過像這樣一起使用
vmap()
和jvp()
,我們可以避免儲存的線性化記憶體成本,該成本會隨著計算深度而擴展,而linearize()
和vjp()
都會產生此成本。以下是使用
linearize()
的更完整範例>>> import jax >>> import jax.numpy as jnp >>> >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.) ... >>> jax.jvp(f, (2.,), (3.,)) (Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True)) >>> y, f_jvp = jax.linearize(f, 2.) >>> print(y) 3.2681944 >>> print(f_jvp(3.)) -5.007528 >>> print(f_jvp(4.)) -6.676704