建構於 JAX 之上#

學習進階 JAX 用法的一個好方法是了解其他程式庫如何使用 JAX,包括它們如何將程式庫整合到其 API 中、它在數學上增加了哪些功能,以及它如何在其他程式庫中用於計算加速。

以下是一些範例,說明如何使用 JAX 的功能來定義跨多個領域和軟體套件的加速計算。

梯度計算#

簡易的梯度計算是 JAX 的主要功能。在 JaxOpt 程式庫 中,value 和 grad 直接在 其原始碼 中的多種最佳化演算法中供使用者使用。

同樣地,上面提到的 Dynamax Optax 配對是梯度啟用歷史上具有挑戰性的估計方法的一個範例,使用 Optax 的最大似然期望值

跨多個裝置在單一核心上的計算加速#

然後可以編譯在 JAX 中定義的模型,以透過 JIT 編譯實現單一計算加速。然後,相同的編譯程式碼可以傳送到 CPU 裝置、GPU 或 TPU 裝置以獲得額外的加速,通常不需要額外的變更。這允許從開發到生產的順暢工作流程。在 Dynamax 中,線性狀態空間模型求解器中計算量大的部分已進行 jitted。一個更複雜的範例來自 PyTensor,它動態編譯 JAX 函數,然後 jits 建構的函數

使用平行化的單一和多電腦加速#

JAX 的另一個優點是使用 pmapvmap 函數呼叫或裝飾器平行化計算的簡便性。在 Dynamax 中,狀態空間模型透過 VMAP 裝飾器 進行平行化,多物件追蹤是這種用例的一個實際範例。

將 JAX 程式碼整合到您或您使用者的工作流程中#

JAX 具有相當的可組合性,並且可以多種方式使用。JAX 可以與獨立模式一起使用,使用者在其中定義所有自己的計算。但是,其他模式,例如使用建立在 jax 之上的程式庫,這些程式庫提供特定功能。這些可以是定義特定模型類型(例如神經網路或狀態空間模型或其他模型)的程式庫,或提供特定功能(例如最佳化)的程式庫。以下是每種模式的更具體範例。

直接使用#

可以如本網站所示直接匯入和使用 Jax 來「從頭開始」建構模型,例如在 JAX 教學使用 JAX 的神經網路 中。如果您無法為您的特定挑戰找到預先建置的程式碼,或者如果您希望減少程式碼庫中的依賴項數量,這可能是最佳選擇。

具有公開 JAX 的可組合領域特定程式庫#

另一種常見方法是提供預先建置功能的套件,無論是模型定義還是某種類型的計算。然後可以混合和匹配這些套件的組合,以實現完整的端對端工作流程,在其中定義模型並估計其參數。

一個範例是 Flax,它簡化了神經網路的建構。然後,Flax 通常與 Optax 配對,其中 Flax 定義神經網路架構,而 Optax 提供最佳化和模型擬合功能。

另一個是 Dynamax,它允許輕鬆定義狀態空間模型。透過 Dynamax,可以使用 使用 Optax 的最大似然法 估計參數,或者可以使用 來自 Blackjax 的 MCMC 估計完整貝氏後驗。

完全對使用者隱藏 JAX#

其他程式庫選擇將 JAX 完全包裝在其模型特定的 API 中。PyMC 和 Pytensor 就是一個範例,使用者可能永遠不會「看到」直接的 JAX,而是使用 PyMC 特定 API 包裝 JAX 函數