jax.hessian#

jax.hessian(fun, argnums=0, has_aux=False, holomorphic=False)[原始碼]#

fun 的 Hessian 矩陣,以密集陣列形式呈現。

參數:
  • fun (Callable) – 要計算 Hessian 矩陣的函式。其由 argnums 指定位置的參數應為陣列、純量或它們的標準 Python 容器。它應傳回陣列、純量或它們的標準 Python 容器。

  • argnums (int | Sequence[int]) – 選填,整數或整數序列。指定要針對哪個位置參數進行微分(預設為 0)。

  • has_aux (bool) – 選填,布林值。指示 fun 是否傳回一個配對,其中第一個元素被視為要微分的數學函式的輸出,而第二個元素是輔助資料。預設為 False。

  • holomorphic (bool) – 選填,布林值。指示 fun 是否保證為全純函數。預設為 False。

傳回值:

一個具有與 fun 相同參數的函式,它評估 fun 的 Hessian 矩陣。

傳回型別:

Callable

>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]

hessian() 是 Hessian 矩陣的常見定義的推廣,它支援巢狀 Python 容器(即 pytrees)作為輸入和輸出。jax.hessian(fun)(x) 的樹狀結構是透過形成 fun(x) 的結構的樹狀乘積,以及兩個 x 結構副本的樹狀乘積來給出的。兩個樹狀結構的樹狀乘積是透過將第一個樹的每個葉節點替換為第二個樹的副本來形成的。例如

>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': Array([[[ 2.,  0.], [ 0.,  0.]],
                         [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': Array([[[0.      , 0.      ], [0.      , 0.      ]],
                         [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}

因此,jax.hessian(fun)(x) 的樹狀結構中的每個葉節點都對應於 fun(x) 的一個葉節點和一對 x 的葉節點。對於 jax.hessian(fun)(x) 中的每個葉節點,如果 fun(x) 的對應陣列葉節點具有形狀 (out_1, out_2, ...),且 x 的對應陣列葉節點分別具有形狀 (in_1_1, in_1_2, ...)(in_2_1, in_2_2, ...),則 Hessian 葉節點具有形狀 (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...)。換句話說,Python 樹狀結構表示 Hessian 矩陣的區塊結構,區塊由輸入和輸出 pytrees 決定。

特別是,當函式輸入 x 和輸出 fun(x) 各自都是單個陣列時,會產生一個陣列(不涉及 pytrees),如上面的 g 範例所示。如果 fun(x) 具有形狀 (out1, out2, ...),且 x 具有形狀 (in1, in2, ...),則 jax.hessian(fun)(x) 具有形狀 (out1, out2, ..., in1, in2, ..., in1, in2, ...)。要將 pytrees 展平為 1D 向量,請考慮使用 jax.flatten_util.flatten_pytree()