外部函式介面 (FFI)#

本教學課程需要 JAX v0.4.31 或更新版本。

雖然可以使用 JAX 內建的 jax.numpyjax.lax 介面輕鬆且有效率地實作各種數值運算,但有時透過「外部函式介面」(FFI) 明確呼叫外部編譯程式庫可能會很有用。當特定運算先前已在最佳化的 C 或 CUDA 程式庫中實作,並且直接使用 JAX 重新實作這些計算並非易事時,這可能特別有用,但對於最佳化 JAX 程式的執行階段或記憶體效能也可能很有用。儘管如此,FFI 通常應被視為最後的手段,因為後端的 XLA 編譯器或提供較低階控制的 Pallas 核心語言,通常以較低的開發和維護成本產生高效能的程式碼。

在考慮使用 FFI 時,應考慮到的一點是,JAX 不會自動知道如何透過外部函式進行微分。這表示如果您想要將 JAX 的自動微分功能與外部函式一起使用,您也需要提供相關微分規則的實作。我們將在下面討論一些可能的方法,但務必從一開始就指出此限制!

JAX 的 FFI 支援分為兩個部分

  1. XLA 的僅標頭 C++ 程式庫,從 v0.4.29 開始作為 JAX 的一部分封裝,或可從 openxla/xla 專案取得,以及

  2. Python 前端,可在 jax.ffi 子模組中使用。

在本教學課程中,我們將透過簡單的範例示範如何使用這兩個組件,然後繼續討論針對更複雜使用案例的較低階擴充功能。我們先介紹 CPU 上的 FFI,然後在下面討論 GPU 或多裝置環境的推廣。

本範例和其他更進階使用案例的端對端程式碼可在 GitHub 上的 JAX FFI 範例專案中找到:examples/ffi 在 JAX 儲存庫中

由於我們將在本教學課程的結尾示範如何對 FFI 呼叫進行分片,因此我們先設定我們的環境,讓 JAX 將其視為具有多個 CPU

import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

簡單範例#

為了示範 FFI 介面的使用,我們將實作一個簡單的「均方根 (RMS)」正規化函式。RMS 正規化接受形狀為 \((N,)\) 的陣列 \(x\) 並傳回

\[ y_n = \frac{x_n}{\sqrt{\frac{1}{N}\sum_{n=1}^N {x_n}^2 + \epsilon}} \]

其中 \(\epsilon\) 是用於數值穩定性的調整參數。

這是一個有點蠢的範例,因為可以使用 JAX 輕鬆實作如下

import jax
import jax.numpy as jnp


def rms_norm_ref(x, eps=1e-5):
  scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
  return x / scale

但是,它只是稍微複雜一些,足以用於示範 FFI 的一些關鍵細節,同時仍然易於理解。我們將使用此參考實作來測試我們下面的 FFI 版本。

後端程式碼#

首先,我們需要在 C++ 中實作 RMS 正規化,我們將使用 FFI 公開它。這並不是要特別高效能,但您可以想像,如果您在 C++ 程式庫中對 RMS 正規化進行了一些新的更好實作,它可能會具有如下介面。因此,以下是 C++ 中 RMS 正規化的簡單實作

#include <cmath>
#include <cstdint>

float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
  float sm = 0.0f;
  for (int64_t n = 0; n < size; ++n) {
    sm += x[n] * x[n];
  }
  float scale = 1.0f / std::sqrt(sm / float(size) + eps);
  for (int64_t n = 0; n < size; ++n) {
    y[n] = x[n] * scale;
  }
  return scale;
}

並且,對於我們的範例,這是我們想要透過 FFI 公開給 JAX 的函式。

C++ 介面#

為了將我們的程式庫函式公開給 JAX 和 XLA,我們需要使用 xla/ffi/api 目錄中僅標頭程式庫提供的 API,撰寫一個精簡的包裝函式。XLA 專案。如需有關此介面的更多資訊,請參閱 XLA 自訂呼叫文件。完整的原始碼清單可以從 這裡 下載,但此處重現了關鍵實作細節

#include <functional>
#include <numeric>
#include <utility>

#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
// In this example, we treat all leading dimensions as batch dimensions, so this
// function returns the total number of elements in the buffer, and the size of
// the last dimension.
template <ffi::DataType T>
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
  auto dims = buffer.dimensions();
  if (dims.size() == 0) {
    return std::make_pair(0, 0);
  }
  return std::make_pair(buffer.element_count(), dims.back());
}

// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
                       ffi::ResultBuffer<ffi::F32> y) {
  auto [totalSize, lastDim] = GetDims(x);
  if (lastDim == 0) {
    return ffi::Error::InvalidArgument("RmsNorm input must be an array");
  }
  for (int64_t n = 0; n < totalSize; n += lastDim) {
    ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
  }
  return ffi::Error::Success();
}

// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    RmsNorm, RmsNormImpl,
    ffi::Ffi::Bind()
        .Attr<float>("eps")
        .Arg<ffi::Buffer<ffi::F32>>()  // x
        .Ret<ffi::Buffer<ffi::F32>>()  // y
);

從底部開始,我們使用 XLA 提供的巨集 XLA_FFI_DEFINE_HANDLER_SYMBOL 來產生一些樣板程式碼,這些程式碼將展開為一個名為 RmsNorm 且具有適當簽名的函式。但是,這裡重要的內容都在對 ffi::Ffi::Bind() 的呼叫中,我們在其中定義輸入和輸出型別,以及任何參數的型別。

然後,在 RmsNormImpl 中,我們接受 ffi::Buffer 引數,其中包含有關緩衝區形狀和指向基礎資料的指標的資訊。在此實作中,我們將緩衝區的所有前導維度視為批次維度,並在最後一個軸上執行 RMS 正規化。GetDims 是一個輔助函式,用於支援此批次處理行為。我們在 下方 更詳細地討論了此批次處理行為,但一般概念是,在輸入引數的最左側維度中透明地處理批次處理可能很有用。在這種情況下,我們將除最後一個軸以外的所有軸都視為批次維度,但其他外部函式可能需要不同數量的非批次維度。

建置和註冊 FFI 處理常式#

現在我們已經實作了我們最小的 FFI 包裝函式,我們需要將此函式 (RmsNorm) 公開給 Python。在本教學課程中,我們將 RmsNorm 編譯成共用程式庫,並使用 ctypes 載入它,但另一種常見模式是使用 nanobindpybind11,如下所述。

為了編譯共用程式庫,我們在這裡使用 CMake,但您應該能夠使用您最喜歡的建置系統,而不會遇到太多麻煩。

!cmake -DCMAKE_BUILD_TYPE=Release -B ffi/_build ffi
!cmake --build ffi/_build
!cmake --install ffi/_build
隱藏程式碼儲存格輸出
-- The CXX compiler identification is GNU 11.4.0
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Python: /home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/bin/python3.10 (found suitable version "3.10.15", minimum required is "3.8") found components: Interpreter Development.Module
<string>:1: DeprecationWarning: jax.extend.ffi.include_dir is deprecated, use jax.ffi.include_dir instead.
-- XLA include directory: /home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include
-- Configuring done (1.3s)
-- Generating done (0.0s)
-- Build files have been written to: /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/_build
[ 50%] Building CXX object CMakeFiles/rms_norm.dir/rms_norm.cc.o
In file included from /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/rms_norm.cc:24:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:654:68: warning: always_inline’ function might not be inlinable []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wattributes-Wattributes]8;;]
  654 | _ATTRIBUTE_ALWAYS_INLINE std::optional<Buffer<dtype, rank>> DecodeBuffer(
      |                                                             ^~~~~~~~~~~~
In file included from /home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:48,
                 from /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/rms_norm.cc:24:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h: In function ‘std::ostream& operator<<(std::ostream&, XLA_FFI_ExecutionStage)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h:180:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
  180 | }
      | ^
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h: In function ‘std::ostream& operator<<(std::ostream&, XLA_FFI_AttrType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h:166:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
  166 | }
      | ^
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h: In function ‘std::ostream& operator<<(std::ostream&, XLA_FFI_DataType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/api.h:153:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
  153 | }
      | ^
In file included from /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/rms_norm.cc:24:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h: In function ‘std::ostream& xla::ffi::operator<<(std::ostream&, XLA_FFI_ArgType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:722:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
  722 | }
      | ^
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h: In function ‘std::ostream& xla::ffi::operator<<(std::ostream&, XLA_FFI_RetType)’:
/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jaxlib/include/xla/ffi/api/ffi.h:797:1: warning: control reaches end of non-void function []8;;https://gcc.gnu.org/onlinedocs/gcc/Warning-Options.html#index-Wreturn-type-Wreturn-type]8;;]
  797 | }
      | ^
[100%] Linking CXX shared library librms_norm.so
[100%] Built target rms_norm
-- Install configuration: "Release"
-- Installing: /home/docs/checkouts/readthedocs.org/user_builds/jax/checkouts/latest/docs/ffi/librms_norm.so

有了這個編譯後的程式庫,我們現在需要透過 register_ffi_target() 函式向 XLA 註冊此處理常式。此函式預期我們的處理常式 (指向 C++ 函式 RmsNorm 的函式指標) 包裝在 PyCapsule 中。JAX 提供了一個輔助函式 pycapsule() 來協助完成此操作

import ctypes
from pathlib import Path

path = next(Path("ffi").glob("librms_norm*"))
rms_norm_lib = ctypes.cdll.LoadLibrary(path)
jax.ffi.register_ffi_target(
    "rms_norm", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")

提示

如果您熟悉舊版「自訂呼叫」API,值得注意的是,您也可以使用 register_ffi_target(),透過手動指定關鍵字引數 api_version=0 來註冊自訂呼叫目標。register_ffi_target() 的預設 api_version1,也就是我們在此使用的新「型別化」FFI API。

替代方法:將處理常式公開給 Python 的常見替代模式是使用 nanobindpybind11 來定義一個可以匯入的小型 Python 擴充功能。對於我們這裡的範例,nanobind 程式碼將是

#include <type_traits>

#include "nanobind/nanobind.h"
#include "xla/ffi/api/c_api.h"

namespace nb = nanobind;

template <typename T>
nb::capsule EncapsulateFfiCall(T *fn) {
  // This check is optional, but it can be helpful for avoiding invalid handlers.
  static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
                "Encapsulated function must be and XLA FFI handler");
  return nb::capsule(reinterpret_cast<void *>(fn));
}

NB_MODULE(rms_norm, m) {
  m.def("rms_norm", []() { return EncapsulateFfiCall(RmsNorm); });
}

然後,在 Python 中,我們可以使用以下程式碼註冊此處理常式

# Assuming that we compiled a nanobind extension called `rms_norm`:
import rms_norm as rms_norm_lib

jax.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu")

前端程式碼#

現在我們已經註冊了我們的 FFI 處理常式,使用 ffi_call() 函式從 JAX 呼叫我們的 C++ 程式庫非常簡單

import numpy as np


def rms_norm(x, eps=1e-5):
  # We only implemented the `float32` version of this function, so we start by
  # checking the dtype. This check isn't strictly necessary because type
  # checking is also performed by the FFI when decoding input and output
  # buffers, but it can be useful to check types in Python to raise more
  # informative errors.
  if x.dtype != jnp.float32:
    raise ValueError("Only the float32 dtype is implemented by rms_norm")

  call = jax.ffi.ffi_call(
    # The target name must be the same string as we used to register the target
    # above in `register_custom_call_target`
    "rms_norm",

    # In this case, the output of our FFI function is just a single array with
    # the same shape and dtype as the input. We discuss a case with a more
    # interesting output type below.
    jax.ShapeDtypeStruct(x.shape, x.dtype),

    # The `vmap_method` parameter controls this function's behavior under `vmap`
    # as discussed below.
    vmap_method="broadcast_all",
  )

  # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
  # the attribute `eps`. Our FFI function expects this to have the C++ `float`
  # type (which corresponds to numpy's `float32` type), and it must be a
  # static parameter (i.e. not a JAX array).
  return call(x, eps=np.float32(eps))


# Test that this gives the same result as our reference implementation
x = jnp.linspace(-0.5, 0.5, 32).reshape((8, 4))
np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)

此程式碼儲存格包含許多內嵌註解,應說明此處發生的大部分情況,但有幾點值得明確強調。此處的大部分繁重工作是由 ffi_call() 函式完成的,它告訴 JAX 如何針對特定輸入集呼叫外部函式。務必注意,ffi_call() 的第一個引數必須是一個字串,該字串與我們在呼叫上面的 register_custom_call_target 時使用的目標名稱相符。

任何屬性 (使用上面的 C++ 包裝函式中的 Attr 定義) 都應作為關鍵字引數傳遞給 ffi_call()。請注意,我們明確地將 eps 轉換為 np.float32,因為我們的 FFI 程式庫預期為 C float,而且我們不能在此處使用 jax.numpy,因為這些參數必須是靜態引數。

ffi_call()vmap_method 引數定義了此 FFI 呼叫如何與 vmap() 互動,如下所述。

提示

如果您熟悉早期的「自訂呼叫」介面,您可能會感到驚訝,我們沒有將問題維度作為參數 (批次大小等) 傳遞給 ffi_call()。在此早期的 API 中,後端沒有接收有關輸入陣列中繼資料的機制,但由於 FFI 將維度資訊與 Buffer 物件包含在一起,因此我們不再需要在降低時使用 Python 計算此資訊。此變更的一個主要好處是 ffi_call() 可以開箱即用地支援一些簡單的 vmap() 語意,如下所述。

使用 vmap 進行批次處理#

ffi_call() 使用 vmap_method 參數,開箱即用地支援一些簡單的 vmap() 語意。pure_callback() 的文件提供了有關 vmap_method 參數的更多詳細資訊,並且相同的行為適用於 ffi_call()

最簡單的 vmap_method"sequential"。在這種情況下,當進行 vmap 時,ffi_call 將被重寫為 scan(),其中主體中包含 ffi_call。此實作是通用用途的,但它的平行化效果不佳。許多 FFI 呼叫提供更有效率的批次處理行為,並且在某些簡單的情況下,可以使用 "expand_dims""broadcast_all" 方法來公開更好的實作。

在這種情況下,由於我們只有一個輸入引數,因此 "expand_dims""broadcast_all" 實際上具有相同的行為。使用這些方法所需的特定假設是外部函式知道如何處理批次維度。另一種說法是,在批次輸入上呼叫 ffi_call 的結果假定等於堆疊 ffi_call 對批次輸入中每個元素的重複應用,大致如下

ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])

提示

請注意,當我們有多個輸入引數時,情況會變得有點複雜。為了簡單起見,我們將在本教學課程中始終使用 "broadcast_all",這保證所有輸入都將廣播以具有相同的批次維度,但也可以實作外部函式來處理 "expand_dims" 方法。pure_callback() 的文件包含一些這方面的範例

我們的 rms_norm 實作具有適當的語意,並且開箱即用支援具有 vmap_method="broadcast_all"vmap

np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)

我們可以檢查 jaxprvmap()rms_norm,以確認它沒有使用 scan() 重新編寫

jax.make_jaxpr(jax.vmap(rms_norm))(x)
{ lambda ; a:f32[8,4]. let
    b:f32[8,4] = ffi_call[
      attributes=(('eps', np.float32(1e-05)),)
      custom_call_api_version=4
      has_side_effect=False
      input_layouts=((1, 0),)
      input_output_aliases=()
      legacy_backend_config=None
      output_layouts=((1, 0),)
      result_avals=(ShapedArray(float32[8,4]),)
      target_name=rms_norm
      vectorized=Deprecated
      vmap_method=broadcast_all
    ] a
  in (b,) }

使用 vmap_method="sequential"vmap 處理 ffi_call 將會回退到 jax.lax.scan(),其中主體中包含 ffi_call

def rms_norm_sequential(x, eps=1e-5):
  return jax.ffi.ffi_call(
    "rms_norm",
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    vmap_method="sequential",
  )(x, eps=np.float32(eps))


jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)
{ lambda ; a:f32[8,4]. let
    b:f32[8,4] = scan[
      _split_transpose=False
      jaxpr={ lambda ; c:f32[4]. let
          d:f32[4] = ffi_call[
            attributes=(('eps', np.float32(1e-05)),)
            custom_call_api_version=4
            has_side_effect=False
            input_layouts=((0,),)
            input_output_aliases=()
            legacy_backend_config=None
            output_layouts=((0,),)
            result_avals=(ShapedArray(float32[4]),)
            target_name=rms_norm
            vectorized=Deprecated
            vmap_method=sequential
          ] c
        in (d,) }
      length=8
      linear=(False,)
      num_carry=0
      num_consts=0
      reverse=False
      unroll=1
    ] a
  in (b,) }

如果您的外部函式提供此簡單 vmap_method 參數不支援的有效批次處理規則,則也可能使用實驗性 custom_vmap 介面定義更彈性的自訂 vmap 規則,但最好也在 JAX 問題追蹤器上開啟一個問題來描述您的使用案例。

微分#

與批次處理不同,ffi_call() 不提供對外部函式自動微分 (AD) 的任何預設支援。就 JAX 而言,外部函式是一個黑箱,無法檢查以確定微分時的適當行為。因此,定義自訂導數規則是 ffi_call() 使用者的責任。

有關自訂導數規則的更多詳細資訊,請參閱 自訂導數教學課程,但用於實作外部函式微分的最常見模式是定義一個 custom_vjp(),它本身會呼叫外部函式。在這種情況下,我們實際上定義了兩個新的 FFI 呼叫

  1. rms_norm_fwd 傳回兩個輸出:(a) 「原始」結果,以及 (b) 在反向傳遞中使用的「殘差」。

  2. rms_norm_bwd 接受殘差和輸出共切線,並傳回輸入共切線。

我們不會深入探討 RMS 正規化反向傳播的細節,但可以參考C++ 原始碼,以了解這些函數在後端的實作方式。這裡主要強調的重點是,計算出的「殘差 (residual)」與原始輸出 (primal output) 的形狀不同,因此在呼叫 ffi_call() 來執行 res_norm_fwd 時,輸出類型有兩個具有不同形狀的元素。

這個自訂導數規則可以如下方式連結 (wired in)

jax.ffi.register_ffi_target(
  "rms_norm_fwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu"
)
jax.ffi.register_ffi_target(
  "rms_norm_bwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu"
)


def rms_norm_fwd(x, eps=1e-5):
  y, res = jax.ffi.ffi_call(
    "rms_norm_fwd",
    (
      jax.ShapeDtypeStruct(x.shape, x.dtype),
      jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
    ),
    vmap_method="broadcast_all",
  )(x, eps=np.float32(eps))
  return y, (res, x)


def rms_norm_bwd(eps, res, ct):
  del eps
  res, x = res
  assert res.shape == ct.shape[:-1]
  assert x.shape == ct.shape
  return (
    jax.ffi.ffi_call(
      "rms_norm_bwd",
      jax.ShapeDtypeStruct(ct.shape, ct.dtype),
      vmap_method="broadcast_all",
    )(res, x, ct),
  )


rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,))
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)

# Check that this gives the right answer when compared to the reference version
ct_y = jnp.ones_like(x)
np.testing.assert_allclose(
  jax.vjp(rms_norm, x)[1](ct_y), jax.vjp(rms_norm_ref, x)[1](ct_y), rtol=1e-5
)

至此,我們可以針對許多 JAX 應用程式透明地使用新的 rms_norm 函數,並且它會在標準 JAX 函數轉換 (例如 vmap()grad()) 下進行適當的轉換。這個範例不支援前向模式自動微分 (AD) (jax.jvp(),例如),因為 custom_vjp() 僅限於反向模式。JAX 目前尚未公開同時自訂前向模式和反向模式 AD 的公共 API,但此類 API 已在規劃中,因此如果您在實務上遇到此限制,請開啟 issue 描述您的使用案例。

這個範例不支援的另一個 JAX 功能是更高階的 AD。可以透過將上面的 res_norm_bwd 函數包裝在 jax.custom_jvp()jax.custom_vjp() 裝飾器中來解決這個問題,但我們不會在此深入探討這個進階使用案例的細節。

GPU 上的 FFI 呼叫#

到目前為止,我們僅針對在 CPU 上執行的外部函數進行介接 (interfacing),但 JAX 的 FFI 也支援呼叫 GPU 程式碼。由於此文件頁面是在沒有 GPU 存取權限的機器上自動產生的,因此我們無法在此處執行任何 GPU 專用範例,但我們將討論重點。

當我們為 CPU 定義 FFI 包裝器時,我們使用的函數簽名是

ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
                       ffi::ResultBuffer<ffi::F32> y)

若要更新此簽名以與 CUDA 核心 (kernel) 介接,則簽名會變成

ffi::Error RmsNormImpl(cudaStream_t stream, float eps,
                       ffi::Buffer<ffi::F32> x,
                       ffi::ResultBuffer<ffi::F32> y)

並且處理常式 (handler) 定義已更新,在其綁定 (binding) 中包含一個 Ctx

XLA_FFI_DEFINE_HANDLER(
    RmsNorm, RmsNormImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Attr<float>("eps")
        .Arg<ffi::Buffer<ffi::F32>>()  // x
        .Ret<ffi::Buffer<ffi::F32>>()  // y
);

然後,RmsNormImpl 可以使用 CUDA stream 來啟動 CUDA 核心。

在前端,註冊程式碼將會更新以指定適當的平台

jax.ffi.register_ffi_target(
  "rms_norm_cuda", rms_norm_lib_cuda.rms_norm(), platform="CUDA"
)

支援多個平台#

為了支援在 GPU 和 CPU 上執行我們的 rms_norm 函數,我們可以將上述實作與 jax.lax.platform_dependent() 函數結合使用

def rms_norm_cross_platform(x, eps=1e-5):
  assert x.dtype == jnp.float32
  out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)

  def impl(target_name):
    return lambda x: jax.ffi.ffi_call(
      target_name,
      out_type,
      vmap_method="broadcast_all",
    )(x, eps=np.float32(eps))

  return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))


np.testing.assert_allclose(rms_norm_cross_platform(x), rms_norm_ref(x), rtol=1e-5)

此版本的函數將根據執行時平台呼叫適當的 FFI 目標。

順帶一提,有趣的是要注意到,雖然 jaxpr 和降低後的 HLO 都包含對兩個 FFI 目標的引用

jax.make_jaxpr(rms_norm_cross_platform)(x)
{ lambda ; a:f32[8,4]. let
    b:i32[] = platform_index[has_default=False platforms=(('cpu',), ('cuda',))] 
    c:i32[] = clamp 0 b 1
    d:f32[8,4] = cond[
      branches=(
        { lambda ; e:f32[8,4]. let
            f:f32[8,4] = ffi_call[
              attributes=(('eps', np.float32(1e-05)),)
              custom_call_api_version=4
              has_side_effect=False
              input_layouts=((1, 0),)
              input_output_aliases=()
              legacy_backend_config=None
              output_layouts=((1, 0),)
              result_avals=(ShapedArray(float32[8,4]),)
              target_name=rms_norm
              vectorized=Deprecated
              vmap_method=broadcast_all
            ] e
          in (f,) }
        { lambda ; g:f32[8,4]. let
            h:f32[8,4] = ffi_call[
              attributes=(('eps', np.float32(1e-05)),)
              custom_call_api_version=4
              has_side_effect=False
              input_layouts=((1, 0),)
              input_output_aliases=()
              legacy_backend_config=None
              output_layouts=((1, 0),)
              result_avals=(ShapedArray(float32[8,4]),)
              target_name=rms_norm_cuda
              vectorized=Deprecated
              vmap_method=broadcast_all
            ] g
          in (h,) }
      )
    ] c a
  in (d,) }
print(jax.jit(rms_norm_cross_platform).lower(x).as_text().strip())
module @jit_rms_norm_cross_platform attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<8x4xf32>) -> (tensor<8x4xf32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %c_1 = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.clamp %c_0, %c, %c_1 : tensor<i32>
    %1 = "stablehlo.case"(%0) ({
      %2 = stablehlo.custom_call @rms_norm(%arg0) {backend_config = "", mhlo.backend_config = {eps = 9.99999974E-6 : f32}, operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<8x4xf32>) -> tensor<8x4xf32>
      stablehlo.return %2 : tensor<8x4xf32>
    }, {
      %2 = stablehlo.custom_call @rms_norm_cuda(%arg0) {backend_config = "", mhlo.backend_config = {eps = 9.99999974E-6 : f32}, operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<8x4xf32>) -> tensor<8x4xf32>
      stablehlo.return %2 : tensor<8x4xf32>
    }) : (tensor<i32>) -> tensor<8x4xf32>
    return %1 : tensor<8x4xf32>
  }
}

但在函數編譯完成時,已選取適當的 FFI

print(jax.jit(rms_norm_cross_platform).lower(x).as_text(dialect="hlo").strip())
HloModule jit_rms_norm_cross_platform, entry_computation_layout={(f32[8,4]{1,0})->f32[8,4]{1,0}}

ENTRY main.3 {
  Arg_0.1 = f32[8,4]{1,0} parameter(0)
  ROOT custom-call.2 = f32[8,4]{1,0} custom-call(Arg_0.1), custom_call_target="rms_norm", operand_layout_constraints={f32[8,4]{1,0}}, api_version=API_VERSION_TYPED_FFI
}

並且使用 jax.lax.platform_dependent() 不會有任何執行時額外負擔,而且編譯後的程式碼不會包含對不可用 FFI 目標的任何引用。

分片 (Sharding)#

大多數 JAX 的大規模使用者都會使用其 API 進行跨多個裝置的分散式運算。如平行程式設計簡介中所討論的,JAX 中的平行處理是由跨裝置分片資料來控制的,並且大多數 JAX 運算都可以在任何支援的平行程式設計範例 (從自動到完全手動) 中使用。但是,對於 FFI 呼叫而言,情況稍微複雜一些。由於 FFI 呼叫的內部機制對於 JAX 和 XLA 都是不透明的,因此當資料被分片時,FFI 呼叫通常不會顯示最佳 (甚至良好) 的效能。

在深入探討 FFI 細節之前,讓我們先考慮 RMS 正規化的純 JAX 參考實作 (本文檔頂部定義的 rms_norm_ref 函數) 在分片輸入下的行為。如上所述,我們的實作將輸入的所有前導軸視為批次維度,並且沿著最後一個軸執行正規化。這表示如果資料沿著任何批次維度分片,但在最後一個維度上複製,則不需要任何通訊。這可以透過沿著第一個維度分片我們上面 2 維的測試資料,並檢查編譯後的 HLO 中是否有 all-gatherall-reduce 等運算來觀察到。

from jax.sharding import PartitionSpec as P

assert len(jax.devices()) == 4  # Set using the XLA_FLAGS environment variable
mesh = jax.make_mesh((4,), ("x",))

batch_shd = jax.NamedSharding(mesh, P("x", None))
x_batch_shd = jax.device_put(x, batch_shd)
hlo_batch = jax.jit(rms_norm_ref, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text()
assert "all-" not in hlo_batch

但是,如果資料沿著最後一個軸分片,則需要通訊 (在本例中為 all-reduce) 來計算正規化中的總和

data_shd = jax.NamedSharding(mesh, P(None, "x"))
x_data_shd = jax.device_put(x, data_shd)
hlo_data = jax.jit(rms_norm_ref, out_shardings=data_shd).lower(x_data_shd).compile().as_text()
assert "all-reduce" in hlo_data

現在,如果我們天真地嘗試使用相同模型的 FFI 版本,它可以正常執行並獲得正確的答案

output = jax.jit(rms_norm, out_shardings=batch_shd)(x_batch_shd)
np.testing.assert_allclose(output, rms_norm_ref(x), rtol=1e-5)

但是,如果您查看編譯後的 HLO (為了清楚起見,省略了輔助函數),您會看到

  1. 資料首先透過 all-gather 運算完全複製到每個裝置上,

  2. FFI 呼叫在每個裝置上的完整資料集上執行,並且

  3. 輸出被切片以丟棄未使用的部分。

hlo = jax.jit(rms_norm, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text().strip()
print(hlo.split("\n\n")[-1])
ENTRY %main.5_spmd (param: f32[2,4]) -> f32[2,4] {
  %param = f32[2,4]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}, metadata={op_name="x"}
  %all-gather = f32[8,4]{1,0} all-gather(f32[2,4]{1,0} %param), channel_id=1, replica_groups=[1,4]<=[4], dimensions={0}, use_global_device_ids=true, metadata={op_name="jit(rms_norm)/jit(main)/ffi_call" source_file="/tmp/ipykernel_924/3540880311.py" source_line=32}
  %custom-call.0 = f32[8,4]{1,0} custom-call(f32[8,4]{1,0} %all-gather), custom_call_target="rms_norm", operand_layout_constraints={f32[8,4]{1,0}}, api_version=API_VERSION_TYPED_FFI, metadata={op_name="jit(rms_norm)/jit(main)/ffi_call" source_file="/tmp/ipykernel_924/3540880311.py" source_line=32}, backend_config={eps = 9.99999974E-6 : f32}
  %partition-id = u32[] partition-id()
  ROOT %multiply_dynamic-slice_fusion = f32[2,4]{1,0} fusion(f32[8,4]{1,0} %custom-call.0, u32[] %partition-id), kind=kLoop, calls=%fused_computation
}

對我們來說,這顯然不是此函數的最佳分割方式,但這是 JAX/XLA 在給定資訊下可以做到的最好方式。

為了產生更好的分割邏輯,我們可以利用 shard_map()custom_partitioning(),我們在此討論這兩種選項。話雖如此,為所有輸入產生最佳分割並非易事,因為有時這需要演算法上的變更。具體來說,讓我們新增對「批次分割」的支援,這可以處理資料在批次維度上分片的情況,但最後一個維度上的分片始終需要重新分片。

使用 shard_map#

如果您透過 shard_map() 使用手動分片控制,則程式中的任何 FFI 呼叫都應已適當地分割

from functools import partial
from jax.experimental.shard_map import shard_map

@partial(shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None))
def rms_norm_shmap(x):
  return rms_norm(x)

np.testing.assert_allclose(rms_norm_shmap(x_batch_shd), rms_norm_ref(x), rtol=1e-5)
print(jax.jit(rms_norm_shmap, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text().strip())
HloModule jit_rms_norm_shmap, is_scheduled=true, entry_computation_layout={(f32[2,4]{1,0})->f32[2,4]{1,0}}, num_partitions=4

ENTRY %main.12_spmd (param: f32[2,4]) -> f32[2,4] {
  %param = f32[2,4]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}, metadata={op_name="x"}
  ROOT %custom-call.1 = f32[2,4]{1,0} custom-call(f32[2,4]{1,0} %param), custom_call_target="rms_norm", operand_layout_constraints={f32[2,4]{1,0}}, api_version=API_VERSION_TYPED_FFI, metadata={op_name="jit(rms_norm_shmap)/jit(main)/jit(shmap_body)/ffi_call" source_file="/tmp/ipykernel_924/3540880311.py" source_line=32}, backend_config={eps = 9.99999974E-6 : f32}
}

正如您在此程式中所見,如果輸入和輸出分片與 shard_map 規格相符,則不需要任何通訊,並且 FFI 呼叫會在適當分片的資料子集上執行。

您也可以使用分片與 shard_map 規格不符的輸入和輸出,但 (與 FFI 無關) 這將需要重新分片,如編譯後的 HLO 中的 all-to-all 運算所示

hlo_data_shmap = jax.jit(rms_norm_shmap, out_shardings=data_shd).lower(x_data_shd).compile().as_text()
assert "all-to-all" in hlo_data_shmap

使用 custom partitioning#

如果您無法使用 shard_map(),另一種方法是使用 custom_partitioning(),它透過 jax.jit() 支援自動平行化。custom_partitioning() 的運作方式是在 XLA 編譯器的分割傳遞 (partitioning pass) 中新增 Python 回呼 (callback),這允許非常彈性的邏輯,但也帶有一些粗糙之處。我們不會在此深入探討注意事項的太多細節,但您應該注意的主要問題是

  1. custom_partitioning 與 JAX 的 持久編譯快取 (Persistent compilation cache) 一起使用時,可能會導致意外的快取未命中 (cache misses)。可以使用 jax_remove_custom_partitioning_ptr_from_cache_key 組態標誌 (configuration flag) 來緩解此問題,但這也並非總是適當的。

  2. 偵錯 custom_partitioning 邏輯可能很繁瑣,因為 Python 錯誤並不總是會傳播,反而會導致您的 Python 處理程序退出。話雖如此,任何例外情況都會顯示在處理程序日誌中,因此您應該能夠在那裡追蹤到它們。

話雖如此,以下是我們如何使用 custom_partitioning() 包裝我們的 rms_norm FFI 實作

from jax.experimental.custom_partitioning import custom_partitioning

@partial(custom_partitioning, static_argnums=(1,))
def rms_norm_partitioned(x, eps=1e-5):
  return rms_norm(x, eps=eps)

def replicate_sharding_on_last_dim(mesh, sharding, target_info):
  # Our implementation supports trivial sharding on any batch dimensions, but the data
  # must be replicated on the last (non-batch) dimension.
  rank = len(target_info.shape)
  num_batch_dims = min(len(sharding.spec), rank - 1)

  # The Nones here indicate which dimensions should be replicated.
  names = tuple(sharding.spec[:num_batch_dims]) + (None,) * (rank - num_batch_dims)
  return jax.NamedSharding(mesh, P(*names))

def rms_norm_infer_sharding_from_operands(eps, mesh, args_info, result_info):
  del eps  # unused
  arg_info, = args_info
  result_sharding = replicate_sharding_on_last_dim(mesh, arg_info.sharding, result_info)

  # In this case, we only have a single output, but the return value from this function
  # must have the same pytree structure as the output from the underlying function
  # (`rms_norm` in this case).
  return result_sharding

def rms_norm_partition(eps, mesh, args_info, result_info):
  arg_info, = args_info
  arg_sharding = replicate_sharding_on_last_dim(mesh, arg_info.sharding, arg_info)
  result_sharding = replicate_sharding_on_last_dim(mesh, arg_info.sharding, result_info)

  # This is the function that computes the partitioned model on the appropriate subset
  # of the data.
  def partitioned_rms_norm(x):
    return rms_norm(x, eps=eps)

  # Note that the third element of our returned tuple must be the shardings for the
  # _outputs_ and its pytree structure must match the output of `rms_norm`. Similarly,
  # the fourth element must have the same pytree structure as the _inputs_ to
  # `rms_norm`. In this case, there is only one input, but it must be returned within
  # a `tuple` anyways.
  return mesh, partitioned_rms_norm, result_sharding, (arg_sharding,)

rms_norm_partitioned.def_partition(
    infer_sharding_from_operands=rms_norm_infer_sharding_from_operands,
    partition=rms_norm_partition,
)

output = jax.jit(rms_norm_partitioned, out_shardings=batch_shd)(x_batch_shd)
np.testing.assert_allclose(output, rms_norm_ref(x), rtol=1e-5)
print(jax.jit(rms_norm_partitioned, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text().strip())
HloModule jit__unnamed_wrapped_function_, is_scheduled=true, entry_computation_layout={(f32[2,4]{1,0})->f32[2,4]{1,0}}, num_partitions=4

ENTRY %main.5_spmd (param: f32[2,4]) -> f32[2,4] {
  %param = f32[2,4]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}, metadata={op_name="args[0]"}
  ROOT %custom-call.0 = f32[2,4]{1,0} custom-call(f32[2,4]{1,0} %param), custom_call_target="rms_norm", operand_layout_constraints={f32[2,4]{1,0}}, api_version=API_VERSION_TYPED_FFI, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/custom_partitioning" source_file="/tmp/ipykernel_924/1708142274.py" source_line=49}, backend_config={eps = 9.99999974E-6 : f32}
}

從上面編譯後的程式碼中您可以看到,當輸入在批次維度上分片時,此 custom_partitioning 邏輯產生的程式碼與上面的 shard_map 版本完全相同。

但是,值得注意的是,當輸入沿著資料維度分片時,行為是不同的。當在 shard_map 下使用時,資料會在批次維度上重新分片,而使用 custom_partitioning 時,資料會收集到每個裝置上。

hlo_data_partitioned = jax.jit(rms_norm_partitioned, out_shardings=data_shd).lower(x_data_shd).compile().as_text().strip()
assert "all-gather" in hlo_data_partitioned

為了也支援反向傳播的自動平行化,我們還需要為 rms_norm_fwdrms_norm_bwd 撰寫 (類似的) custom_partitioning() 規則,但我們將這些留給讀者作為練習。

進階主題#

本教學課程涵蓋了開始使用 JAX 的 FFI 所需的大多數基本步驟,但進階使用案例可能需要更多功能。我們將把這些主題留給未來的教學課程,但以下是一些可能有用的參考資料

  • 支援多種資料類型 (dtypes):在本教學課程的範例中,我們限制為僅支援 float32 輸入和輸出,但許多使用案例需要支援多種不同的輸入類型。處理這個問題的一個選項是為所有支援的輸入類型註冊不同的 FFI 目標,然後使用 Python 根據輸入類型為 jax.ffi.ffi_call() 選擇適當的目標。但是,根據支援案例的組合數學,這種方法可能會很快變得笨拙。因此,也可以定義 C++ 處理常式以接受 ffi::AnyBuffer 而不是 ffi::Buffer<Dtype>。然後,輸入緩衝區將包含一個 element_type() 方法,可用於在後端定義適當的資料類型分派 (dispatching) 邏輯。

  • 具狀態的外部函數 (Stateful foreign functions):也可以使用 FFI 來包裝具有相關狀態的函數。在 XLA 測試套件中包含了一個低階範例,未來的教學課程將包含更多詳細資訊。