持續編譯快取#
JAX 針對已編譯的程式具有可選的磁碟快取。如果啟用,JAX 將在磁碟上儲存已編譯程式的副本,這可以在重複執行相同或相似任務時節省重新編譯時間。
注意:如果編譯快取不在本機檔案系統上,則需要安裝 etils。
pip install etils
用法#
快速開始#
import jax
import jax.numpy as jnp
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")
@jax.jit
def f(x):
return x + 1
x = jnp.zeros((2, 2))
f(x)
設定快取目錄#
當設定了快取位置時,就會啟用編譯快取。這應該在第一次編譯之前完成。依照以下方式設定位置
(1) 使用環境變數
在 shell 中,執行腳本之前
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"
或在 Python 腳本的頂部
import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
(2) 使用 jax.config.update()
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
(3) 使用 set_cache_dir()
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("/tmp/jax_cache")
快取閾值#
jax_persistent_cache_min_compile_time_secs
:只有當編譯時間長於指定值時,才會將計算寫入持久快取。預設值為 1.0 秒。jax_persistent_cache_min_entry_size_bytes
:將在持續編譯快取中快取的最小條目大小(以位元組為單位)-1
:停用大小限制並防止覆寫。保持預設值(
0
)以允許覆寫。覆寫通常會確保最小大小對於用於快取的檔案系統來說是最佳的。> 0
:所需的實際最小大小;不覆寫。
請注意,要快取函數,需要滿足這兩個條件。
額外快取#
XLA 支援額外的快取機制,可以與 JAX 的持續編譯快取一起啟用,以進一步縮短重新編譯時間。
jax_persistent_cache_enable_xla_caches
:可能的值all
:啟用所有 XLA 快取功能none
:不啟用任何額外的 XLA 快取功能xla_gpu_kernel_cache_file
:僅啟用核心快取xla_gpu_per_fusion_autotune_cache_dir
:(預設值)僅啟用自動調整快取
Google Cloud#
在 Google Cloud 上執行時,編譯快取可以放置在 Google Cloud Storage (GCS) 儲存桶上。我們建議以下配置
在與工作負載將要執行的區域相同的區域中建立儲存桶。
在與工作負載 VM 相同的專案中建立儲存桶。確保已設定權限,以便 VM 可以寫入儲存桶。
對於較小的工作負載,不需要複製。較大的工作負載可能會受益於複製。
對於儲存桶的預設儲存類別,請使用「標準」。
將軟刪除政策設定為最短:7 天。
將物件生命週期設定為工作負載運行的預期持續時間。例如,如果預期工作負載運行 10 天,則將物件生命週期設定為 10 天。這應該涵蓋整個運行期間發生的重新啟動。針對生命週期條件使用
age
,針對動作使用Delete
。請參閱物件生命週期管理以取得詳細資訊。如果未設定物件生命週期,快取將會持續增長,因為沒有實作驅逐機制。支援所有加密政策。
假設 gs://jax-cache
是 GCS 儲存桶,請依照以下方式設定快取位置
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
運作方式#
快取金鑰是已編譯函數的簽章,包含以下參數
函數執行的計算,由正在雜湊的 JAX 函數的非優化 HLO 擷取
jaxlib 版本
相關的 XLA 編譯旗標
裝置配置通常由裝置數量和裝置拓撲擷取。目前對於 GPU,拓撲僅包含 GPU 名稱的字串表示形式
用於壓縮已編譯可執行檔的壓縮演算法
由
jax._src.cache_key.custom_hook()
產生的字串。可以將此函數重新指派給使用者定義的函數,以便可以更改產生的字串。預設情況下,此函數始終返回空字串。
多節點快取#
程式第一次運行時(持久快取是冷的/空的),所有進程都將編譯,但只有全域通訊群組中排名為 0 的進程才會寫入持久快取。在後續運行中,所有進程都將嘗試從持久快取中讀取,因此持久快取位於共用檔案系統(例如:NFS)或遠端儲存(例如:GFS)中非常重要。如果持久快取是本機的排名 0,則除了排名 0 之外的所有進程將再次在後續運行中編譯,這是由於編譯快取未命中造成的。
記錄快取活動#
檢查持久編譯快取到底發生了什麼對於除錯很有幫助。以下是一些關於如何開始的建議。
使用者可以透過放置以下內容來啟用相關原始碼檔案的記錄
import os
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"
在腳本的頂部。或者,您可以使用以下內容更改全域 jax 記錄層級
import os
os.environ["JAX_LOGGING_LEVEL"] = "DEBUG"
# or locally with
jax.config.update("jax_logging_level", "DEBUG")
檢查快取未命中#
為了檢查和理解為何會發生快取未命中,JAX 包含一個配置旗標,該旗標啟用記錄所有快取未命中(包括持續編譯快取未命中)及其說明。雖然目前,這僅針對追蹤快取未命中實作,但最終目標是解釋所有快取未命中。可以透過設定以下配置來啟用此功能。
jax.config.update("jax_explain_cache_misses", True)
陷阱#
目前已發現一些陷阱
目前,持久快取不適用於具有主機回呼的函數。在這種情況下,完全避免快取。
這是因為 HLO 包含指向回呼的指標,即使計算和計算基礎架構完全相同,指標也會在每次運行之間發生變化。
目前,持久快取不適用於使用實作其自身 custom_partitioning 的基本運算元的函數。
函數的 HLO 包含指向 custom_partitioning 回呼的指標,並且對於跨運行的相同計算會產生不同的快取金鑰。
在這種情況下,快取仍然會繼續進行,但每次都會產生不同的金鑰,從而使快取失效。
解決 custom_partitioning
的方法#
如前所述,編譯快取不適用於由實作 custom_partitioning
的基本運算元組成的函數。但是,可以使用 shard_map 來規避那些實作它的基本運算元的 custom_partitioning
,並使編譯快取按預期工作
假設我們有一個函數 F
,它實作了 layernorm,然後使用實作 custom_partitioning
的基本運算元 LayerNorm
進行矩陣乘法
import jax
def F(x1, x2, gamma, beta):
ln_out = LayerNorm(x1, gamma, beta)
return ln_out @ x2
如果我們只是在沒有 shard_map 的情況下編譯此函數,則每次運行相同程式碼時,layernorm_matmul_without_shard_map
的快取金鑰都會不同
layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_sharding=(...))(x1, x2, gamma, beta)
但是,如果我們將 layernorm 基本運算元包裝在 shard_map 中,並定義一個執行相同計算的函數 G,則儘管 LayerNorm
實作了 custom_partitioning
,layernorm_matmul_with_shard_map
的快取金鑰每次都會相同
import jax
from jax.experimental.shard_map import shard_map
def G(x1, x2, gamma, beta, mesh, ispecs, ospecs):
ln_out = shard_map(LayerNorm, mesh, in_specs=ispecs, out_specs=ospecs, check_rep=False)(x1, x2, gamma, beta)
return ln_out @ x2
ispecs = jax.sharding.PartitionSpec(...)
ospecs = jax.sharding.PartitionSpec(...)
mesh = jax.sharding.Mesh(...)
layernorm_matmul_with_shard_map = jax.jit(G, static_argnames=['mesh', 'ispecs', 'ospecs'])(x1, x2, gamma, beta, mesh, ispecs, ospecs)
請注意,實作 custom_partitioning
的基本運算元必須包裝在 shard_map 中才能解決此問題。將外部函數 F
包裝在 shard_map 中是不夠的。