複製誘導集合的有效轉置#
mattjj@, dougalm@
2023 年 8 月
動機#
我們在自動轉置包含特定集合的 shmap
時遇到效率問題。問題出在 psum
和 all_gather
,特別是當集合的輸出以未映射的輸出形式傳回給呼叫者時。這不是邊緣情況:例如,當將 grad
應用於基於 shmap
的批次資料平行神經網路損失函數時,該函數使用 psum
來計算總損失時,就會發生這種情況。
我們已經知道這個問題有一段時間了。與 pmap
存在類似的問題,儘管已透過將 grad
保留在 pmap
內部而不是外部來解決。未完成的 avals-with-names 工作的主要目標是解決此轉置效率問題的一個版本。本文檔借鑒了這些想法,同時擴展和修改了它們,以處理更多情況並更容易實現。實際上,此處提出的解決方案僅影響 shmap
實作。系統的其餘部分無需更改(目前)。
本文檔的主要目的是定義此轉置效率問題,並提出一個易於實現的解決方案。
本文檔不討論
陣列上的邏輯軸名稱(此處唯一的軸名稱就像
shmap
和 OGpmap
中的一樣);更改自動微分語義(所有數字和(非)錯誤都保持不變,我們只是讓事情更有效率);
允許使用者程式碼反映任何新資訊,或實際上完全影響使用者程式碼。
問題:psum
或 all_gather
的有效轉置取決於餘切是否在裝置之間不變#
考慮這個半真實的範例,旨在類似於複製參數批次資料平行損失函數
devices = jax.devices() # 8 devices
@partial(shmap, mesh=Mesh(devices, ('batch',)),
in_specs=(P(None, None), P('batch', None)),
out_specs=P())
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
global_loss = lax.pmean(local_loss, 'batch'))
return global_loss
請注意 out_specs=P()
,這表示未映射的輸出。如果您不熟悉未映射輸出的概念,請參閱本文檔底部的附錄。
loss
範例中的大多數細節都不重要。對我們而言,重要的是我們在最後應用 psum
(或者更確切地說是 pmean = lambda x, name: psum(x, name) / psum(1, name)
)。因此,精簡的版本如下所示
# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
我們甚至透過抑制 mesh
引數來簡化表示法。在後續的範例中,可以從上下文中推斷出來。
轉置是什麼樣子?寫入 t
表示函式轉置,我們可以透過應用下面的函式 ¿f1_transpose?
,有效率地評估任何 ybar
的 t(f1)(ybar)
# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
但這不是我們目前作為 t(f1) 得到的轉置。
相反,目前的轉置方法大致是我們切換 in_specs
和 out_specs
,對未映射的輸出進行一些除法重新縮放,然後轉置主體。由於 psum
是它自己的轉置(作為 all-reduce 總和),我們最終產生了這個轉置
# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
in_specs=P(), out_specs=P('i'))
此轉置獲得了正確的數字,但很浪費。我們從轉置的 in_specs=P()
靜態地知道,對於每個函式實例,ybar
都具有相同的值,即其值對於沿著名為 i
的網格軸的裝置來說是不變的,但我們仍然對其應用 psum
!這使用了昂貴的通訊,只是為了將每個裝置上的值乘以 8。(此處的 8 是指軸 i 的大小。除以 8 來自原始函式的 out_specs=P()
;它和微不足道的 psum
基本上相互抵消。)
我們哪裡做錯了?我們沒有利用與 f1
的未映射輸出對應的餘切 ybar
保證是不變裝置的事實;相反,我們防禦性地 psum
它們,彷彿它們不是,因為給定它擁有的本地資訊,psum
的轉置無法確定。有時 psum
是必要的,例如在轉置 f2
相對於其第一個引數時
# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
直觀地說,如果我們的轉置機制可以區分範例 1 和範例 2,我們可以透過盡可能避免 psum 和除法來做得更好。
效率低下的範例甚至可以更小。考慮轉置這個被詛咒的恆等函式
# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())
# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P())
...
我們轉置得越多,它就變得越大。多麼尷尬!
而且 psum
不是唯一的罪魁禍首。all_gather
也存在類似的情況
# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))
這個程式有點人工化。為什麼要執行 all_gather
並將結果饋送到未映射的輸出,而不是跳過主體中的 all_gather
,而只是使用 out_specs=P('i')
來收集結果?但即使它是虛構的,這個範例仍然展示了一個不必要地執行通訊的轉置(我們本來可以只執行非通訊切片),類似於 psum
的範例 1。
同樣類似於 psum
範例,在某些情況下,防禦性 psum_scatter
是必要的
# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
in_specs=(P('i'), P('i')), out_specs=P('i'))
那麼我們如何避免這些效率低下的轉置?
解決方案#
以下是兩個解決方案的想法。它們不是互斥的。但是(劇透)第二個更好,而且這就是我們所需要的。
部分解決方案「P-sum」:建構將 psum
表示為 out_specs
的能力#
這個解決方案有點稻草人,因為它只提供了一種笨拙的程式撰寫方式。而且它甚至無法解決所有問題!但值得考慮,即使只是為了激發更完整的解決方案。
上面的範例 4 是人工化的,因為我們可以只使用 out_specs
而不是主體中的 all_gather
# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))
f4_better
版本沒有任何轉置問題,因為轉置問題來自主體中的集合。
類似地,我們可以透過擴展 out_specs
來修復範例 1,以便它們可以表示求和
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i')) # sum='i' means sum over that axis
# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), P(sum='i'))
因此,在 out_specs
中內建提供 psum
可以解決範例 1 的轉置問題。但它不能完全解決範例 3 中被詛咒的恆等轉置
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
這是一個改進,因為程式不會隨著我們不斷轉置而繼續變大,但我們仍然在進行浪費的通訊。
完整解決方案:靜態追蹤裝置變動與裝置不變的中介值,加上新的基本運算#
此解決方案包含兩個組件
追蹤值何時保證在特定網格軸上是裝置不變的,而不是裝置變動的,以及
將
psum
分解為兩步驟流程,引入新的pbroadcast
基本運算,並為all_gather
及其轉置引入新的基本運算。
在道德上,追蹤裝置不變與裝置變動的資訊是類型層級的考量。但為了我們第一次實作的權宜之計,我們不需要真正將資訊添加到抽象值或 jaxpr 類型中。在我們開始實作之前,我們先使用類型介紹這個想法。
接下來還將討論如何讓使用者 API 方便且向後相容。但為了首先介紹這個想法,我們將忽略便利性,而是撰寫盡可能明確的程式碼。
在 avals 中追蹤裝置不變性(又名 avals-with-names,復活)#
有時我們可以僅從靜態資訊中得知,shmap
主體中某些中間變數的值保證沿著網格軸是不變的,從這個意義上說,沿著網格軸的函式實例(及其對應的裝置)都必須使用相同的值進行計算。我們將把這些值稱為裝置不變的。對於非裝置不變的值,我們將說它們是裝置變動的,儘管實際上我們指的是從類型系統的角度來看可能是裝置變動的。
為了在類型中編碼裝置變異性,我們將擴展陣列的類型語法。我們將撰寫類似 x:f32[3,4]{i}
的內容,以指示 x
沿著網格軸 i
(可能)是裝置變動的(並且在 shmap
的任何其他網格軸上是裝置不變的)。更一般地說,我們將說陣列類型語法的語法類似於
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}
我們還將更新類型規則以處理裝置變異性類型
對於集合以外的一階基本運算
對於多元基本運算,運算元裝置變異性類型在形狀必須相等的地方必須相等,例如
mul x:f32[s1]{r1} y:f32[s2][r2]
除了s1 == s2
之外,還需要r1 == r2
輸出裝置變異性類型必須與運算元相同
對於高階基本運算
我們只是實例化任何類型變數,包括裝置變異性類型(並且檢查類型的相等性會檢查其裝置變異性類型是否相等)
(當執行類型推斷時,例如對於
cond
的分支,我們取裝置變異性類型中軸名稱集合的聯集)
對於一階集合
集合可以接受裝置變動或裝置不變的輸入(沿著與其軸名稱參數對應的網格軸);將裝置不變的運算元傳遞給接受裝置變動運算元的集合,反之亦然,都是錯誤的
集合可以產生裝置變動或裝置不變的輸出
請參閱下表。作為一個附帶好處,無論什麼邏輯實作此類型檢查,都可以取代
shmap
的「靜態分析」檢查,以檢查shmap
主體函式是否與任何未映射的out_specs
相容。
下表總結了集合基本運算的裝置變異性類型
名稱 |
裝置變異性類型 |
範例 |
降低到 HLO |
轉置 |
---|---|---|---|---|
|
|
|
|
|
|
|
|
無運算(無通訊) |
|
|
|
|
|
|
|
|
|
|
不適用 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
這裡有一些令人驚訝的事情!
我們介紹了幾個新的基本運算,包括
pbroadcast
,有趣的是,它會降低為無運算all_gather_invariant
,它降低到與all_gather
相同的東西,但具有不同的裝置變異性類型(基本上all_gather
具有融合在其中的pbroadcast
,而all_gather_invariant
則沒有)pscatter
,它是all_gather_invariant
的對偶(轉置)
all_gather 具有裝置變動的結果
直觀地說,引入 pbroadcast
(除了使類型規則生效之外)的原因是為了讓 psum
可以轉置為物理上的空操作 (no-op)。我們需要 all_gather
具有裝置變異 (device-varying) 結果的原因是為了讓我們可以將其轉置為 psum_scatter
;如果我們將其結果保持為裝置不變 (device-invariant),我們可能需要下游的 pbroadcast
,而這種組合將轉置為效率低下的 psum
,然後進行切片 / pscatter
。因此,我們將 pbroadcast
「融合到」all_gather
中,從而實現有效率地轉置為 psum_scatter
。我們提供 all_gather_invariant
及其轉置 pscatter
主要是為了完整性;使用者不太可能需要它(它對應於範例 4 中的情況,該情況很容易使用 out_specs
以不同的方式編寫)。
有趣的是,psum
和 pbroadcast
轉置對應於使用者在使用 pmap
訓練 LLM 時引入的 psum_idrev
和 id_psumrev
。
此系統如何解決效率低下的轉置範例#
再次考慮簡化的動機範例
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
w:f32[]{i} = g(x)
y:f32[]{} = psum(w, 'i')
return y
有了這些新規則,轉置是
# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
in_specs=P(), out_specs=P('i'))
# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1_transpose(ybar: f32[]):
wbar:f32[]{i} = pbroadcast(ybar, 'i')
xbar:f32[3,4]{i} = transpose(g)(wbar)
return xbar
其中評估 pbroadcast
應用程式根本不涉及任何通訊或 FLOP;它是一個空操作 (no-op)。請注意,如果我們持續轉置,主體的大小不會增加;實際上 t(t(f1)) == f1
。效率提升了!
只要我們在需要的地方使用 pbroadcast
來檢查類型,我們也不會搞砸其他範例
# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.
# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
直觀地說,在範例 1 中,我們現在只有「原始 psum 的一半」,而在範例 2 中,我們得到「兩半」。對於範例 3,我們根本不需要主體中的任何操作。
對於 all_gather
範例,範例 4 需要使用 all_reduce_invariant
才能進行有效率的轉置(儘管最好使用 out_specs
而不是主體中的 collective)
# Example 4 rewritten with explicit all_reduce_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())
# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
y:f32[8]{} = all_gather_invariant(x, 'i')
return y
# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
xbar:f32[1]{i} = pscatter(ybar, 'i')
return xbar
對於範例 5,使用裝置變異 (device-varying) 的 all_gather
可以如我們所願地運作
# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
z:f32[8]{i} = all_gather(x, 'i')
w:f32[8]{i} = z * y
return w
# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
zbar:f32[8]{i} = wbar * y
xbar:f32[1]{i} = psum_scatter(zbar, 'i')
return xbar
如何使 API 對使用者來說更方便(並且向後相容)#
但是,哪個使用者想要編寫 pbroadcast
?又有哪個開發人員想要破壞大量現有的使用者程式碼,這些程式碼涉及未饋送到未映射輸出 (unmapped output) 的 psum
?我不想!
相反,我們可以自動插入 pbroadcast
。這有點類似於我們在 jax.numpy
層執行自動階級提升 (rank promotion) 的方式,插入 broadcast 以避免二元運算符中的階級不匹配錯誤。但是,這要簡單得多,因為我們不需要處理形狀元組。典型的規則是:每當我們看到一個多元運算,其中運算元的裝置變異類型 (device variance type) 不一致時,取運算元裝置變異類型 (device variance type) 的軸名稱集合 (axis name set) 的聯集,並插入 pbroadcast
以將每個運算元提升到產生的裝置變異類型 (device variance type)。
在需要 pbroadcast
之前自動插入它們,可能意味著我們將相同的 pbroadcast
多次應用於相同的運算元,從而建立常見的子表達式。當我們轉置時,這些可能會變成 sum-of-psum
而不是 psum
-of-sum。我們將依靠編譯器來適當地清理它。如果這是一個問題,那麼我們可以向 pbroadcast
插入通道添加一些簡單的記憶化 (memoization)。
用於 all_gather
的使用者 API 預設將表示 all_gather_p
(而不是 all_gather_invariant_p
),涵蓋常見情況,並且意味著不必插入 pbroadcast
。
我們可以在 shmap
上提供一個選項來停用 pbroadcast
的自動插入,在這種情況下,將由使用者來確保類型正確性。對於一些想要明確指出 psum
在反向傳遞中發生位置的人來說,這個明確的選項可能很有吸引力。
如何實作解決方案#
使實作輕量化的關鍵在於我們不會將這些類型添加到 avals 或 jaxprs 中。至少,一開始不會。這可能會很耗費資源,因為它需要更新 JAX 的其餘部分,例如,avals 和 jaxprs 的所有消費者可能都需要處理新類型。我們不會再犯同樣的錯誤!
相反,我們將把這些擴展類型保留為 shmap
內部的元數據,就像當前「用於 out_specs
的複製檢查」機制在 shmap
內部一樣。實際上,此解決方案相當於對現有機制進行相對較小的擴展:它已經在追蹤相同的資訊;現在我們只是添加了 pbroadcast
。
我們至少有兩種選擇可以在哪裡執行 pbroadcast
插入
就在轉置之前,在轉置規則中,我們有一個要轉置的計算的 jaxpr;
在每個
shmap
主體中,無論是立即執行還是階段性輸出,都像當前「用於out_specs
的複製檢查」機制一樣。前者可能最終會更容易,因為我們只需要處理 jaxpr 的情況,而且只需要線性原語。但我們將首先嘗試後者,因此此處的實作是對現有複製檢查邏輯的嚴格修訂/擴展。
附錄:定義和激勵具有未映射輸入和輸出的映射#
為了具體起見,我們將主要關注 shmap
,儘管這些相同的想法適用於例如 pmap
,甚至可能是 xmap
。
當 in_specs
的對應條目沒有提及該網格軸的名稱時,參數/輸入沿網格軸是未映射的 (unmapped)。從邏輯上講,這意味著沿該網格軸的每個函數實例都獲得相同的參數值。對於呼叫者來說,每個運算元都根據運算元映射到的網格軸進行切片,而對於運算元未映射到的網格軸則不進行切片。
當 out_specs
的對應條目沒有提及該網格軸的名稱時,輸出沿網格軸是未映射的 (unmapped)。從邏輯上講,這意味著沿該網格軸的每個函數實例都必須傳回相同的值。對於呼叫者來說,shmap
的每個結果都是通過串連每個函數實例的傳回值形成的,而輸出映射到這些函數實例,而對於輸出未映射到的網格軸,僅使用該值的一個副本。
有關未映射輸入和輸出的範例,請參閱 shmap
JEP。為了比較,在 vmap
中,未映射的輸入/輸出通過使用 in_axes
/ out_axes
的 None
(而不是 int
)來指示。
以下是我們喜歡 shmap
的未映射輸入和輸出的原因
與
pjit
相同的表達能力。pjit
可以做的任何事情,shmap
的應急方案也應該能夠做到。否則我們的應急方案就會有所欠缺!如果我們在shmap
中沒有未映射的輸出,那麼我們就無法表達與pjit
相同的批次並行損失函數計算。閉包輸入 (Closed-over input)。 閉包輸入 (Closed-over input) 本質上對應於未映射的輸入,並且…
轉置下的閉包 (Closure under transposition)。 一旦我們有了未映射的輸入,自然就可以轉置為未映射的輸出。
因此,未映射的輸出既是規範的,又是有用的!