GPU 記憶體配置#

JAX 會在第一次執行 JAX 操作時預先配置 75% 的 GPU 總記憶體。 預先配置可最大程度地減少配置開銷和記憶體碎片,但有時可能會導致記憶體不足 (OOM) 錯誤。如果您的 JAX 程序因 OOM 而失敗,則可以使用以下環境變數來覆寫預設行為

XLA_PYTHON_CLIENT_PREALLOCATE=false

這會停用預先配置行為。JAX 將改為根據需要配置 GPU 記憶體,從而可能減少整體記憶體使用量。但是,此行為更容易發生 GPU 記憶體碎片,這表示使用大部分可用 GPU 記憶體的 JAX 程式可能會在停用預先配置的情況下發生 OOM。

XLA_PYTHON_CLIENT_MEM_FRACTION=.XX

如果啟用預先配置,這會使 JAX 預先配置 XX% 的 GPU 總記憶體,而不是預設的 75%。降低預先配置量可以修正 JAX 程式啟動時發生的 OOM。

XLA_PYTHON_CLIENT_ALLOCATOR=platform

這會使 JAX 完全依照需求配置,並釋放不再需要的記憶體(請注意,這是唯一會釋放 GPU 記憶體而不是重複使用的配置)。這非常慢,因此不建議一般使用,但可能適用於以最小可能的 GPU 記憶體佔用量執行或偵錯 OOM 失敗。

OOM 失敗的常見原因#

同時執行多個 JAX 程序。

使用 XLA_PYTHON_CLIENT_MEM_FRACTION 為每個程序提供適當的記憶體量,或設定 XLA_PYTHON_CLIENT_PREALLOCATE=false

同時執行 JAX 和 GPU TensorFlow。

TensorFlow 預設也會預先配置,因此這與同時執行多個 JAX 程序類似。

一種解決方案是使用僅限 CPU 的 TensorFlow(例如,如果您僅使用 TF 進行資料載入)。您可以使用命令 tf.config.experimental.set_visible_devices([], "GPU") 防止 TensorFlow 使用 GPU

或者,使用 XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATE。還有類似的選項可以配置 TensorFlow 的 GPU 記憶體配置(TF1 中的 gpu_memory_fractionallow_growth,應在傳遞至 tf.Sessiontf.ConfigProto 中設定。請參閱 使用 GPU:限制 GPU 記憶體成長 以取得 TF2)。

在顯示 GPU 上執行 JAX。

使用 XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATE

停用重新實體化 HLO 傳遞

有時停用自動重新實體化 HLO 傳遞有利於避免編譯器做出不良的 remat 選擇。可以透過設定 jax.config.update('enable_remat_opt_pass', True)jax.config.update('enable_remat_opt_pass', False) 分別啟用/停用傳遞。啟用或停用自動 remat 傳遞會在運算和記憶體之間產生不同的權衡。但請注意,該演算法是基本的,並且通常可以透過停用自動 remat 傳遞並使用 jax.remat API 手動執行來獲得更好的運算和記憶體權衡

實驗性功能#

此處的功能是實驗性的,必須謹慎嘗試。

TF_GPU_ALLOCATOR=cuda_malloc_async

這會將 XLA 自己的 BFC 記憶體配置器替換為 cudaMallocAsync。這將移除大的固定預先配置,並使用成長型記憶體池。預期的好處是不需要設定 XLA_PYTHON_CLIENT_MEM_FRACTION

風險是

  • 記憶體碎片不同,因此如果您接近限制,則由於碎片造成的確切 OOM 情況將會不同。

  • 配置時間不會在開始時全部支付,而是在需要增加記憶體池時產生。因此,您可能會在開始時遇到較小的速度穩定性,並且對於基準測試,忽略前幾次迭代將更加重要。

可以透過預先配置大量區塊來減輕風險,並且仍然可以獲得成長型記憶體池的好處。這可以使用 TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=N 完成。如果 N 為 -1,它將預先配置與預設配置相同的量。否則,它是您要預先配置的大小(以位元組為單位)。