GPU 效能提示#

本文檔著重於神經網路工作負載的效能提示

Matmul 精確度#

在最新的 GPU 世代,例如 Nvidia A100 世代或更新的世代,以 bfloat16 精確度執行大多數計算可能是個好主意。 例如,如果使用 Flax,請使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16) 實例化 Dense 層。以下是一些程式碼範例

XLA 效能旗標#

注意

JAX-Toolbox 也有關於 NVIDIA XLA 效能旗標的頁面。

XLA 旗標的存在和確切行為可能與 jaxlib 版本有關。

截至 jaxlib==0.4.18(於 2023 年 10 月 6 日發布),設定這些 XLA 旗標可以提高效能。 其中一些與 GPU 之間的通訊有關,因此僅在多個裝置上執行計算時才相關,而另一些則與每個裝置上的程式碼產生有關。

其中一些旗標可能在未來的版本中預設設定。

這些旗標可以透過 XLA_FLAGS shell 環境變數設定。 例如,我們可以將其添加到 Python 檔案的頂部

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
)

如需更多範例,另請參閱 Nvidia GPU 上 Pax 訓練建議的 XLA 旗標

程式碼產生旗標#

  • –xla_gpu_triton_gemm_any 對於任何支援的 GEMM (matmul),使用基於 Triton 的 GEMM 發射器。 預設值為 False。

通訊旗標#

  • –xla_gpu_enable_latency_hiding_scheduler 此旗標啟用延遲隱藏排程器,以有效率地將非同步通訊與計算重疊。 預設值為 False。

  • –xla_gpu_memory_limit_slop_factor 此旗標作為應用於總可用記憶體的乘數,建立一個閾值,引導延遲隱藏排程器 (LHS) 在記憶體減少和延遲隱藏最佳化之間取得平衡。 預設值為 95。

    此因子有效地為編譯器傳遞建立記憶體限制,決定排程器何時應優先考慮

    1. 記憶體減少:當記憶體使用量接近或超過計算的閾值時。

    2. 延遲隱藏:當記憶體使用量低於閾值時,允許更積極的最佳化,這些最佳化可能會暫時增加記憶體使用量,但能提高整體效能。

    透過調整此因子,使用者可以微調記憶體效率和效能最佳化之間的權衡。

  • –xla_gpu_enable_pipelined_collectives 當使用管線平行處理時,此旗標啟用將第 (i+1) 層權重 AllGather 與第 i 層計算重疊。 它還啟用將第 (i+1) 層權重 Reduce/ReduceScatter 與第 i 層的計算重疊。 預設值為 False。 當此旗標開啟時,存在一些錯誤。

  • –xla_gpu_collective_permute_decomposer_threshold 當執行 GSPMD 管線化時,此旗標很有用。 設定非零閾值會將 CollectivePermute 分解為 CollectivePermuteReceiveDoneCollectivePermuteSendDone 對,以便可以在每個對應的 ReceiveDone/SendDone 對之間執行計算,從而實現更多重疊。 預設情況下,閾值為 0,並且沒有分解。 將其設定為閾值 > 0,例如 --xla_gpu_collective_permute_decomposer_threshold=1024 可以啟用此功能。

  • –xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 這些旗標調整何時將多個小的 AllGather/ReduceScatter/AllReduce 合併為一個大的 AllGather/ReduceScatter/AllReduce,以減少花費在跨裝置通訊上的時間。 例如,對於基於 Transformer 的工作負載上的 AllGather/ReduceScatter 閾值,請考慮將它們調整得足夠高,以便至少合併 Transformer 層的權重 AllGather/ReduceScatter。 預設情況下,combine_threshold_bytes 設定為 256。

NCCL 旗標#

這些 Nvidia NCCL 旗標值可能適用於 Nvidia GPU 上的單主機多裝置計算

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 })

這些 NCCL 旗標可以提高單主機通訊速度。 這些旗標似乎對多主機通訊尚無用處。

多進程 (Multi-Process)#

我們建議每個 GPU 使用一個進程,而不是每個節點一個進程。 在某些情況下,這可以加速 jitted 計算。jax.distributed.initialize() API 將在 SLURM 下執行時自動理解該配置。 但是,這只是一個經驗法則,在您的用例中測試每個 GPU 一個進程和每個節點一個進程可能都很有用。