GPU 效能提示#
本文檔著重於神經網路工作負載的效能提示
Matmul 精確度#
在最新的 GPU 世代,例如 Nvidia A100 世代或更新的世代,以 bfloat16
精確度執行大多數計算可能是個好主意。 例如,如果使用 Flax,請使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16)
實例化 Dense
層。以下是一些程式碼範例
在 Flax LM1B 範例中,
Dense
模組以可配置的 dtype 實例化,該 dtype 預設為 bfloat16。在 MaxText 中,
DenseGeneral
模組也以可配置的 dtype 實例化,該 dtype 預設為 bfloat16。
XLA 效能旗標#
注意
JAX-Toolbox 也有關於 NVIDIA XLA 效能旗標的頁面。
XLA 旗標的存在和確切行為可能與 jaxlib
版本有關。
截至 jaxlib==0.4.18
(於 2023 年 10 月 6 日發布),設定這些 XLA 旗標可以提高效能。 其中一些與 GPU 之間的通訊有關,因此僅在多個裝置上執行計算時才相關,而另一些則與每個裝置上的程式碼產生有關。
其中一些旗標可能在未來的版本中預設設定。
這些旗標可以透過 XLA_FLAGS
shell 環境變數設定。 例如,我們可以將其添加到 Python 檔案的頂部
import os
os.environ['XLA_FLAGS'] = (
'--xla_gpu_triton_gemm_any=True '
'--xla_gpu_enable_latency_hiding_scheduler=true '
)
如需更多範例,另請參閱 Nvidia GPU 上 Pax 訓練建議的 XLA 旗標。
程式碼產生旗標#
–xla_gpu_triton_gemm_any 對於任何支援的 GEMM (matmul),使用基於 Triton 的 GEMM 發射器。 預設值為 False。
通訊旗標#
–xla_gpu_enable_latency_hiding_scheduler 此旗標啟用延遲隱藏排程器,以有效率地將非同步通訊與計算重疊。 預設值為 False。
–xla_gpu_memory_limit_slop_factor 此旗標作為應用於總可用記憶體的乘數,建立一個閾值,引導延遲隱藏排程器 (LHS) 在記憶體減少和延遲隱藏最佳化之間取得平衡。 預設值為 95。
此因子有效地為編譯器傳遞建立記憶體限制,決定排程器何時應優先考慮
記憶體減少:當記憶體使用量接近或超過計算的閾值時。
延遲隱藏:當記憶體使用量低於閾值時,允許更積極的最佳化,這些最佳化可能會暫時增加記憶體使用量,但能提高整體效能。
透過調整此因子,使用者可以微調記憶體效率和效能最佳化之間的權衡。
–xla_gpu_enable_pipelined_collectives 當使用管線平行處理時,此旗標啟用將第 (i+1) 層權重
AllGather
與第 i 層計算重疊。 它還啟用將第 (i+1) 層權重Reduce
/ReduceScatter
與第 i 層的計算重疊。 預設值為 False。 當此旗標開啟時,存在一些錯誤。–xla_gpu_collective_permute_decomposer_threshold 當執行 GSPMD 管線化時,此旗標很有用。 設定非零閾值會將
CollectivePermute
分解為CollectivePermuteReceiveDone
和CollectivePermuteSendDone
對,以便可以在每個對應的ReceiveDone
/SendDone
對之間執行計算,從而實現更多重疊。 預設情況下,閾值為 0,並且沒有分解。 將其設定為閾值 > 0,例如--xla_gpu_collective_permute_decomposer_threshold=1024
可以啟用此功能。–xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 這些旗標調整何時將多個小的
AllGather
/ReduceScatter
/AllReduce
合併為一個大的AllGather
/ReduceScatter
/AllReduce
,以減少花費在跨裝置通訊上的時間。 例如,對於基於 Transformer 的工作負載上的AllGather
/ReduceScatter
閾值,請考慮將它們調整得足夠高,以便至少合併 Transformer 層的權重AllGather
/ReduceScatter
。 預設情況下,combine_threshold_bytes
設定為 256。
NCCL 旗標#
這些 Nvidia NCCL 旗標值可能適用於 Nvidia GPU 上的單主機多裝置計算
os.environ.update({
"NCCL_LL128_BUFFSIZE": "-2",
"NCCL_LL_BUFFSIZE": "-2",
"NCCL_PROTO": "SIMPLE,LL,LL128",
})
這些 NCCL 旗標可以提高單主機通訊速度。 這些旗標似乎對多主機通訊尚無用處。
多進程 (Multi-Process)#
我們建議每個 GPU 使用一個進程,而不是每個節點一個進程。 在某些情況下,這可以加速 jitted 計算。jax.distributed.initialize()
API 將在 SLURM 下執行時自動理解該配置。 但是,這只是一個經驗法則,在您的用例中測試每個 GPU 一個進程和每個節點一個進程可能都很有用。