jax.scipy.optimize.minimize#

jax.scipy.optimize.minimize(fun, x0, args=(), *, method, tol=None, options=None)[原始碼]#

最小化一個或多個變數的純量函數。

此函式的 API 與 SciPy 相符,但有一些細微的差異

  • 當需要時,fun 的梯度會使用 JAX 的自動微分支援自動計算。

  • 需要 method 引數。您必須指定求解器。

  • SciPy 介面中的各種可選引數尚未實作。

  • 由於線搜尋實作的差異,最佳化結果可能與 SciPy 不同。

minimize 支援 jit() 編譯。它尚不支援微分或多維陣列形式的引數,但計劃支援這兩者。

參數:
  • fun (Callable) – 要最小化的目標函數,fun(x, *args) -> float,其中 x 是形狀為 (n,) 的 1-D 陣列,而 args 是一個 tuple,其中包含完全指定函數所需的固定參數。fun 必須支援微分。

  • x0 (jax.Array) – 初始猜測。大小為 (n,) 的實數元素陣列,其中 n 是獨立變數的數量。

  • args (tuple) – 傳遞給目標函數的額外引數。

  • method (str) – 求解器型別。目前僅支援 "BFGS"

  • tol (float | None | None) – 終止容忍度。如需詳細控制,請使用求解器特定的選項。

  • options (Mapping[str, Any] | None | None) –

    求解器選項的字典。所有方法都接受以下通用選項

    • maxiter (int): 要執行的最大迭代次數。根據方法,每次迭代可能會使用多次函數評估。

返回:

一個 OptimizeResults 物件。

返回型別:

OptimizeResults