建構於 JAX 之上#
學習進階 JAX 用法的一個好方法是了解其他程式庫如何使用 JAX,包括它們如何將程式庫整合到其 API 中、它在數學上增加了哪些功能,以及它如何在其他程式庫中用於計算加速。
以下是一些範例,說明如何使用 JAX 的功能來定義跨多個領域和軟體套件的加速計算。
梯度計算#
簡易的梯度計算是 JAX 的主要功能。在 JaxOpt 程式庫 中,value 和 grad 直接在 其原始碼 中的多種最佳化演算法中供使用者使用。
同樣地,上面提到的 Dynamax Optax 配對是梯度啟用歷史上具有挑戰性的估計方法的一個範例,使用 Optax 的最大似然期望值。
跨多個裝置在單一核心上的計算加速#
然後可以編譯在 JAX 中定義的模型,以透過 JIT 編譯實現單一計算加速。然後,相同的編譯程式碼可以傳送到 CPU 裝置、GPU 或 TPU 裝置以獲得額外的加速,通常不需要額外的變更。這允許從開發到生產的順暢工作流程。在 Dynamax 中,線性狀態空間模型求解器中計算量大的部分已進行 jitted。一個更複雜的範例來自 PyTensor,它動態編譯 JAX 函數,然後 jits 建構的函數。
使用平行化的單一和多電腦加速#
JAX 的另一個優點是使用 pmap
和 vmap
函數呼叫或裝飾器平行化計算的簡便性。在 Dynamax 中,狀態空間模型透過 VMAP 裝飾器 進行平行化,多物件追蹤是這種用例的一個實際範例。
將 JAX 程式碼整合到您或您使用者的工作流程中#
JAX 具有相當的可組合性,並且可以多種方式使用。JAX 可以與獨立模式一起使用,使用者在其中定義所有自己的計算。但是,其他模式,例如使用建立在 jax 之上的程式庫,這些程式庫提供特定功能。這些可以是定義特定模型類型(例如神經網路或狀態空間模型或其他模型)的程式庫,或提供特定功能(例如最佳化)的程式庫。以下是每種模式的更具體範例。
直接使用#
可以如本網站所示直接匯入和使用 Jax 來「從頭開始」建構模型,例如在 JAX 教學 或 使用 JAX 的神經網路 中。如果您無法為您的特定挑戰找到預先建置的程式碼,或者如果您希望減少程式碼庫中的依賴項數量,這可能是最佳選擇。
具有公開 JAX 的可組合領域特定程式庫#
另一種常見方法是提供預先建置功能的套件,無論是模型定義還是某種類型的計算。然後可以混合和匹配這些套件的組合,以實現完整的端對端工作流程,在其中定義模型並估計其參數。
一個範例是 Flax,它簡化了神經網路的建構。然後,Flax 通常與 Optax 配對,其中 Flax 定義神經網路架構,而 Optax 提供最佳化和模型擬合功能。
另一個是 Dynamax,它允許輕鬆定義狀態空間模型。透過 Dynamax,可以使用 使用 Optax 的最大似然法 估計參數,或者可以使用 來自 Blackjax 的 MCMC 估計完整貝氏後驗。