jax.scipy.optimize.minimize#
- jax.scipy.optimize.minimize(fun, x0, args=(), *, method, tol=None, options=None)[原始碼]#
最小化一個或多個變數的純量函數。
此函式的 API 與 SciPy 相符,但有一些細微的差異
當需要時,
fun
的梯度會使用 JAX 的自動微分支援自動計算。需要
method
引數。您必須指定求解器。SciPy 介面中的各種可選引數尚未實作。
由於線搜尋實作的差異,最佳化結果可能與 SciPy 不同。
minimize
支援jit()
編譯。它尚不支援微分或多維陣列形式的引數,但計劃支援這兩者。- 參數:
fun (Callable) – 要最小化的目標函數,
fun(x, *args) -> float
,其中x
是形狀為(n,)
的 1-D 陣列,而args
是一個 tuple,其中包含完全指定函數所需的固定參數。fun
必須支援微分。x0 (jax.Array) – 初始猜測。大小為
(n,)
的實數元素陣列,其中n
是獨立變數的數量。args (tuple) – 傳遞給目標函數的額外引數。
method (str) – 求解器型別。目前僅支援
"BFGS"
。tol (float | None | None) – 終止容忍度。如需詳細控制,請使用求解器特定的選項。
options (Mapping[str, Any] | None | None) –
求解器選項的字典。所有方法都接受以下通用選項
maxiter (int): 要執行的最大迭代次數。根據方法,每次迭代可能會使用多次函數評估。
- 返回:
一個
OptimizeResults
物件。- 返回型別: