XLA 編譯器旗標列表#

簡介#

本指南簡要概述 XLA 以及 XLA 與 Jax 的關係。如需深入瞭解,請參閱 XLA 文件。然後列出常用的 XLA 編譯器旗標,旨在最佳化 Jax 程式的效能。

XLA:Jax 背後的動力#

XLA (Accelerated Linear Algebra,加速線性代數) 是一種用於線性代數的特定領域編譯器,在 Jax 的效能和彈性方面扮演著關鍵角色。它使 Jax 能夠透過轉換和編譯您的 Python/NumPy 類型的程式碼為高效的機器指令,為各種硬體後端 (CPU、GPU、TPU) 產生最佳化的程式碼。

Jax 使用 XLA 的 JIT 編譯功能,在執行階段將您的 Python 函式轉換為最佳化的 XLA 計算。

在 Jax 中設定 XLA:#

您可以在執行 Python 腳本或 colab 筆記本之前,透過設定 XLA_FLAGS 環境變數來影響 XLA 在 Jax 中的行為。

對於 colab 筆記本

使用 os.environ['XLA_FLAGS'] 提供旗標

import os

# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'

對於 python 腳本

XLA_FLAGS 指定為 cli 命令的一部分

XLA_FLAGS='--flag1=value1 --flag2=value2'  python3 source.py

重要注意事項

  • 在匯入 Jax 或其他相關程式庫之前設定 XLA_FLAGS。在後端初始化之後變更 XLA_FLAGS 將不會有任何效果,並且由於後端初始化時間未明確定義,因此通常更安全的方式是在執行任何 Jax 程式碼之前設定 XLA_FLAGS

  • 嘗試不同的旗標,以最佳化您的特定使用案例的效能。

更多資訊

  • 關於 XLA 的完整且最新的文件,請參閱官方 XLA 文件

  • 對於開放原始碼版本的 XLA (CPU、GPU) 支援的後端,XLA 旗標及其預設值定義於 xla/debug_options_flags.cc 中,完整的旗標列表可以在這裡找到。

  • TPU 編譯器旗標不屬於 OpenXLA 的一部分,但常用的選項列於下方。

  • 請注意,此旗標列表並非詳盡無遺,並且可能會變更。這些旗標是實作細節,無法保證它們將保持可用或維持其目前的行為。

常見的 XLA 旗標#

旗標

類型

注意事項

xla_dump_to

字串 (檔案路徑)

將放置預先最佳化 HLO 檔案和其他artifacts的資料夾 (請參閱 XLA 工具)。

xla_enable_async_collective_permute

TristateFlag (true/false/auto)

將所有 collective-permute 操作重寫為其非同步變體。當設定為 auto 時,XLA 可以根據其他組態或條件自動開啟非同步 collective。

xla_enable_async_all_gather

TristateFlag (true/false/auto)

若設定為 true,則啟用非同步 all gather。若為 auto,則僅針對實作非同步 all-gather 的平台啟用。實作 (例如 BC-offload 或 continuation fusion) 根據其他旗標值選擇。

xla_disable_hlo_passes

字串 (逗號分隔的傳遞名稱列表)

要停用的 HLO 傳遞的逗號分隔列表。這些名稱必須完全符合傳遞名稱 (逗號周圍沒有空格)。

TPU XLA 旗標#

旗標

類型

注意事項

xla_tpu_enable_data_parallel_all_reduce_opt

布林值 (true/false)

最佳化以增加用於資料平行分片的 DCN (資料中心網路) all-reduce 的重疊機會。

xla_tpu_data_parallel_opt_different_sized_ops

布林值 (true/false)

即使資料平行操作的輸出大小與堆疊變數中可以就地儲存的大小不符,也能跨多個迭代啟用資料平行操作的管線化。可能會增加記憶體壓力。

xla_tpu_enable_async_collective_fusion

布林值 (true/false)

啟用將非同步 collective 通訊與計算操作 (輸出/迴圈融合或卷積) 融合的傳遞,這些計算操作排程在它們的 -start 和 -done 指令之間。

xla_tpu_enable_async_collective_fusion_fuse_all_gather

TristateFlag (true/false/auto)

啟用在 AsyncCollectiveFusion 傳遞中融合 all-gather。
若設定為 auto,則將根據目標啟用。

xla_tpu_enable_async_collective_fusion_multiple_steps

布林值 (true/false)

啟用在 AsyncCollectiveFusion 傳遞中的多個步驟 (融合) 中繼續相同的非同步 collective。

xla_tpu_overlap_compute_collective_tc

布林值 (true/false)

啟用單一 TensorCore 上計算和通訊的重疊,即一個核心相當於 MegaCore 融合。

xla_tpu_spmd_rng_bit_generator_unsafe

布林值 (true/false)

是否以分割方式執行 RngBitGenerator HLO,如果預期在計算的不同部分使用不同分片的確定性結果,則此方式是不安全的。

xla_tpu_megacore_fusion_allow_ags

布林值 (true/false)

允許將 all-gather 與卷積/all-reduce 融合。

xla_tpu_enable_ag_backward_pipelining

布林值 (true/false)

透過掃描迴圈向後管線化 all-gather (目前為 megascale all-gather)。

GPU XLA 旗標#

旗標

類型

注意事項

xla_gpu_enable_latency_hiding_scheduler

布林值 (true/false)

此旗標啟用延遲隱藏排程器,以有效率地將非同步通訊與計算重疊。預設值為 False。

xla_gpu_enable_triton_gemm

布林值 (true/false)

使用基於 Triton 的矩陣乘法。

xla_gpu_graph_level

旗標 (0-3)

用於設定 GPU 圖形層級的舊版旗標。在新使用案例中使用 xla_gpu_enable_command_buffer。0 = 關閉;1 = 捕獲融合和 memcopy;2 = 捕獲 gemm;3 = 捕獲卷積。

xla_gpu_all_reduce_combine_threshold_bytes

整數 (位元組)

這些旗標調整何時將多個小的 AllGather / ReduceScatter / AllReduce 組合為一個大的 AllGather / ReduceScatter / AllReduce,以減少花費在跨裝置通訊上的時間。例如,對於基於 Transformer 的工作負載上的 AllGather / ReduceScatter 閾值,請考慮將它們調整得足夠高,以便至少組合 Transformer 層的權重 AllGather / ReduceScatter。預設情況下,combine_threshold_bytes 設定為 256。

xla_gpu_all_gather_combine_threshold_bytes

整數 (位元組)

請參閱上方的 xla_gpu_all_reduce_combine_threshold_bytes。

xla_gpu_reduce_scatter_combine_threshold_bytes

整數 (位元組)

請參閱上方的 xla_gpu_all_reduce_combine_threshold_bytes。

xla_gpu_enable_pipelined_all_gather

布林值 (true/false)

啟用 all-gather 指令的管線化。

xla_gpu_enable_pipelined_reduce_scatter

布林值 (true/false)

啟用 reduce-scatter 指令的管線化。

xla_gpu_enable_pipelined_all_reduce

布林值 (true/false)

啟用 all-reduce 指令的管線化。

xla_gpu_enable_while_loop_double_buffering

布林值 (true/false)

啟用 while 迴圈的雙重緩衝。

xla_gpu_enable_all_gather_combine_by_dim

布林值 (true/false)

組合具有相同收集維度或不考慮其維度的 all-gather 操作。

xla_gpu_enable_reduce_scatter_combine_by_dim

布林值 (true/false)

組合具有相同維度或不考慮其維度的 reduce-scatter 操作。

延伸閱讀