複製誘導集合的有效轉置#

mattjj@, dougalm@

2023 年 8 月

動機#

我們在自動轉置包含特定集合的 shmap 時遇到效率問題。問題出在 psumall_gather,特別是當集合的輸出以未映射的輸出形式傳回給呼叫者時。這不是邊緣情況:例如,當將 grad 應用於基於 shmap 的批次資料平行神經網路損失函數時,該函數使用 psum 來計算總損失時,就會發生這種情況。

我們已經知道這個問題有一段時間了。與 pmap 存在類似的問題,儘管已透過將 grad 保留在 pmap 內部而不是外部來解決。未完成的 avals-with-names 工作的主要目標是解決此轉置效率問題的一個版本。本文檔借鑒了這些想法,同時擴展和修改了它們,以處理更多情況並更容易實現。實際上,此處提出的解決方案僅影響 shmap 實作。系統的其餘部分無需更改(目前)。

本文檔的主要目的是定義此轉置效率問題,並提出一個易於實現的解決方案。

本文檔不討論

  • 陣列上的邏輯軸名稱(此處唯一的軸名稱就像 shmap 和 OG pmap 中的一樣);

  • 更改自動微分語義(所有數字和(非)錯誤都保持不變,我們只是讓事情更有效率);

  • 允許使用者程式碼反映任何新資訊,或實際上完全影響使用者程式碼。

問題:psumall_gather 的有效轉置取決於餘切是否在裝置之間不變#

考慮這個半真實的範例,旨在類似於複製參數批次資料平行損失函數

devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch'))
  return global_loss

請注意 out_specs=P(),這表示未映射的輸出。如果您不熟悉未映射輸出的概念,請參閱本文檔底部的附錄。

loss 範例中的大多數細節都不重要。對我們而言,重要的是我們在最後應用 psum(或者更確切地說是 pmean = lambda x, name: psum(x, name) / psum(1, name))。因此,精簡的版本如下所示

# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

我們甚至透過抑制 mesh 引數來簡化表示法。在後續的範例中,可以從上下文中推斷出來。

轉置是什麼樣子?寫入 t 表示函式轉置,我們可以透過應用下面的函式 ¿f1_transpose?,有效率地評估任何 ybart(f1)(ybar)

# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))

但這不是我們目前作為 t(f1) 得到的轉置。

相反,目前的轉置方法大致是我們切換 in_specsout_specs,對未映射的輸出進行一些除法重新縮放,然後轉置主體。由於 psum 是它自己的轉置(作為 all-reduce 總和),我們最終產生了這個轉置

# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
              in_specs=P(), out_specs=P('i'))

此轉置獲得了正確的數字,但很浪費。我們從轉置的 in_specs=P() 靜態地知道,對於每個函式實例,ybar 都具有相同的值,即其值對於沿著名為 i 的網格軸的裝置來說是不變的,但我們仍然對其應用 psum!這使用了昂貴的通訊,只是為了將每個裝置上的值乘以 8。(此處的 8 是指軸 i 的大小。除以 8 來自原始函式的 out_specs=P();它和微不足道的 psum 基本上相互抵消。)

我們哪裡做錯了?我們沒有利用與 f1 的未映射輸出對應的餘切 ybar 保證是不變裝置的事實;相反,我們防禦性地 psum 它們,彷彿它們不是,因為給定它擁有的本地資訊,psum 的轉置無法確定。有時 psum 是必要的,例如在轉置 f2 相對於其第一個引數時

# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
          in_specs=(P('i'), P('i')), out_specs=P('i'))

# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
                in_specs=(P('i'), P('i')), out_specs=P('i'))

直觀地說,如果我們的轉置機制可以區分範例 1 和範例 2,我們可以透過盡可能避免 psum 和除法來做得更好。

效率低下的範例甚至可以更小。考慮轉置這個被詛咒的恆等函式

# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())

# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P())
...

我們轉置得越多,它就變得越大。多麼尷尬!

而且 psum 不是唯一的罪魁禍首。all_gather 也存在類似的情況

# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))

這個程式有點人工化。為什麼要執行 all_gather 並將結果饋送到未映射的輸出,而不是跳過主體中的 all_gather,而只是使用 out_specs=P('i') 來收集結果?但即使它是虛構的,這個範例仍然展示了一個不必要地執行通訊的轉置(我們本來可以只執行非通訊切片),類似於 psum 的範例 1。

同樣類似於 psum 範例,在某些情況下,防禦性 psum_scatter 是必要的

# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))

那麼我們如何避免這些效率低下的轉置?

解決方案#

以下是兩個解決方案的想法。它們不是互斥的。但是(劇透)第二個更好,而且這就是我們所需要的。

部分解決方案「P-sum」:建構將 psum 表示為 out_specs 的能力#

這個解決方案有點稻草人,因為它只提供了一種笨拙的程式撰寫方式。而且它甚至無法解決所有問題!但值得考慮,即使只是為了激發更完整的解決方案。

上面的範例 4 是人工化的,因為我們可以只使用 out_specs 而不是主體中的 all_gather

# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))

f4_better 版本沒有任何轉置問題,因為轉置問題來自主體中的集合。

類似地,我們可以透過擴展 out_specs 來修復範例 1,以便它們可以表示求和

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i'))  # sum='i' means sum over that axis

# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), P(sum='i'))

因此,在 out_specs 中內建提供 psum 可以解決範例 1 的轉置問題。但它不能完全解決範例 3 中被詛咒的恆等轉置

# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())

# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))

這是一個改進,因為程式不會隨著我們不斷轉置而繼續變大,但我們仍然在進行浪費的通訊。

完整解決方案:靜態追蹤裝置變動與裝置不變的中介值,加上新的基本運算#

此解決方案包含兩個組件

  1. 追蹤值何時保證在特定網格軸上是裝置不變的,而不是裝置變動的,以及

  2. psum 分解為兩步驟流程,引入新的 pbroadcast 基本運算,並為 all_gather 及其轉置引入新的基本運算。

在道德上,追蹤裝置不變與裝置變動的資訊是類型層級的考量。但為了我們第一次實作的權宜之計,我們不需要真正將資訊添加到抽象值或 jaxpr 類型中。在我們開始實作之前,我們先使用類型介紹這個想法。

接下來還將討論如何讓使用者 API 方便且向後相容。但為了首先介紹這個想法,我們將忽略便利性,而是撰寫盡可能明確的程式碼。

在 avals 中追蹤裝置不變性(又名 avals-with-names,復活)#

有時我們可以僅從靜態資訊中得知,shmap 主體中某些中間變數的值保證沿著網格軸是不變的,從這個意義上說,沿著網格軸的函式實例(及其對應的裝置)都必須使用相同的值進行計算。我們將把這些值稱為裝置不變的。對於非裝置不變的值,我們將說它們是裝置變動的,儘管實際上我們指的是從類型系統的角度來看可能是裝置變動的。

為了在類型中編碼裝置變異性,我們將擴展陣列的類型語法。我們將撰寫類似 x:f32[3,4]{i} 的內容,以指示 x 沿著網格軸 i(可能)是裝置變動的(並且在 shmap 的任何其他網格軸上是裝置不變的)。更一般地說,我們將說陣列類型語法的語法類似於

shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}

我們還將更新類型規則以處理裝置變異性類型

  • 對於集合以外的一階基本運算

    • 對於多元基本運算,運算元裝置變異性類型在形狀必須相等的地方必須相等,例如 mul x:f32[s1]{r1} y:f32[s2][r2] 除了 s1 == s2 之外,還需要 r1 == r2

    • 輸出裝置變異性類型必須與運算元相同

  • 對於高階基本運算

    • 我們只是實例化任何類型變數,包括裝置變異性類型(並且檢查類型的相等性會檢查其裝置變異性類型是否相等)

    • (當執行類型推斷時,例如對於 cond 的分支,我們取裝置變異性類型中軸名稱集合的聯集)

  • 對於一階集合

    • 集合可以接受裝置變動或裝置不變的輸入(沿著與其軸名稱參數對應的網格軸);將裝置不變的運算元傳遞給接受裝置變動運算元的集合,反之亦然,都是錯誤的

    • 集合可以產生裝置變動或裝置不變的輸出

    • 請參閱下表。作為一個附帶好處,無論什麼邏輯實作此類型檢查,都可以取代 shmap 的「靜態分析」檢查,以檢查 shmap 主體函式是否與任何未映射的 out_specs 相容。

下表總結了集合基本運算的裝置變異性類型

名稱

裝置變異性類型

範例

降低到 HLO

轉置

psum2

變動 -> 不變

y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')

AllReduceSum(通訊)

pbroadcast

pbroadcast

不變 -> 變動

y:f32[3]{i} = pbroadcast(x:f32[3], 'i')

無運算(無通訊)

psum

all_to_all

變動 -> 變動

y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0) AllToAll(通訊)

all_to_all

axis_index

() -> 變動

idx:i32[]{i} = axis_index('i')

ReplicaId 和一些算術(無通訊)

不適用

psum_scatter

變動 -> 變動

y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')

ReduceScatterSum(通訊)

all_gather

all_gather

變動 -> 變動

y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')

AllGather(通訊)

psum_scatter

pscatter

不變 -> 變動

y:f32[2]{i} = pscatter(x:f32[16], 'i')

lambda x: x[axis_index('i'), None](無通訊)

all_gather_invariant

all_gather_invariant

變動 -> 不變

y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')

AllGather(通訊)

pscatter

這裡有一些令人驚訝的事情!

  • 我們介紹了幾個新的基本運算,包括

    • pbroadcast,有趣的是,它會降低為無運算

    • all_gather_invariant,它降低到與 all_gather 相同的東西,但具有不同的裝置變異性類型(基本上 all_gather 具有融合在其中的 pbroadcast,而 all_gather_invariant 則沒有)

    • pscatter,它是 all_gather_invariant 的對偶(轉置)

  • all_gather 具有裝置變動的結果

直觀地說,引入 pbroadcast(除了使類型規則生效之外)的原因是為了讓 psum 可以轉置為物理上的空操作 (no-op)。我們需要 all_gather 具有裝置變異 (device-varying) 結果的原因是為了讓我們可以將其轉置為 psum_scatter;如果我們將其結果保持為裝置不變 (device-invariant),我們可能需要下游的 pbroadcast,而這種組合將轉置為效率低下的 psum,然後進行切片 / pscatter。因此,我們將 pbroadcast 「融合到」all_gather 中,從而實現有效率地轉置為 psum_scatter。我們提供 all_gather_invariant 及其轉置 pscatter 主要是為了完整性;使用者不太可能需要它(它對應於範例 4 中的情況,該情況很容易使用 out_specs 以不同的方式編寫)。

有趣的是,psumpbroadcast 轉置對應於使用者在使用 pmap 訓練 LLM 時引入的 psum_idrevid_psumrev

此系統如何解決效率低下的轉置範例#

再次考慮簡化的動機範例

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y

有了這些新規則,轉置是

# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
              in_specs=P(), out_specs=P('i'))

# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar

其中評估 pbroadcast 應用程式根本不涉及任何通訊或 FLOP;它是一個空操作 (no-op)。請注意,如果我們持續轉置,主體的大小不會增加;實際上 t(t(f1)) == f1。效率提升了!

只要我們在需要的地方使用 pbroadcast 來檢查類型,我們也不會搞砸其他範例

# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))


# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.

# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())

直觀地說,在範例 1 中,我們現在只有「原始 psum 的一半」,而在範例 2 中,我們得到「兩半」。對於範例 3,我們根本不需要主體中的任何操作。

對於 all_gather 範例,範例 4 需要使用 all_reduce_invariant 才能進行有效率的轉置(儘管最好使用 out_specs 而不是主體中的 collective)

# Example 4 rewritten with explicit all_reduce_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())

# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y

# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar

對於範例 5,使用裝置變異 (device-varying) 的 all_gather 可以如我們所願地運作

# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w

# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar

如何使 API 對使用者來說更方便(並且向後相容)#

但是,哪個使用者想要編寫 pbroadcast?又有哪個開發人員想要破壞大量現有的使用者程式碼,這些程式碼涉及未饋送到未映射輸出 (unmapped output) 的 psum?我不想!

相反,我們可以自動插入 pbroadcast。這有點類似於我們在 jax.numpy 層執行自動階級提升 (rank promotion) 的方式,插入 broadcast 以避免二元運算符中的階級不匹配錯誤。但是,這要簡單得多,因為我們不需要處理形狀元組。典型的規則是:每當我們看到一個多元運算,其中運算元的裝置變異類型 (device variance type) 不一致時,取運算元裝置變異類型 (device variance type) 的軸名稱集合 (axis name set) 的聯集,並插入 pbroadcast 以將每個運算元提升到產生的裝置變異類型 (device variance type)。

在需要 pbroadcast 之前自動插入它們,可能意味著我們將相同的 pbroadcast 多次應用於相同的運算元,從而建立常見的子表達式。當我們轉置時,這些可能會變成 sum-of-psum 而不是 psum-of-sum。我們將依靠編譯器來適當地清理它。如果這是一個問題,那麼我們可以向 pbroadcast 插入通道添加一些簡單的記憶化 (memoization)。

用於 all_gather 的使用者 API 預設將表示 all_gather_p(而不是 all_gather_invariant_p),涵蓋常見情況,並且意味著不必插入 pbroadcast

我們可以在 shmap 上提供一個選項來停用 pbroadcast 的自動插入,在這種情況下,將由使用者來確保類型正確性。對於一些想要明確指出 psum 在反向傳遞中發生位置的人來說,這個明確的選項可能很有吸引力。

如何實作解決方案#

使實作輕量化的關鍵在於我們不會將這些類型添加到 avals 或 jaxprs 中。至少,一開始不會。這可能會很耗費資源,因為它需要更新 JAX 的其餘部分,例如,avals 和 jaxprs 的所有消費者可能都需要處理新類型。我們不會再犯同樣的錯誤!

相反,我們將把這些擴展類型保留為 shmap 內部的元數據,就像當前「用於 out_specs 的複製檢查」機制在 shmap 內部一樣。實際上,此解決方案相當於對現有機制進行相對較小的擴展:它已經在追蹤相同的資訊;現在我們只是添加了 pbroadcast

我們至少有兩種選擇可以在哪裡執行 pbroadcast 插入

  1. 就在轉置之前,在轉置規則中,我們有一個要轉置的計算的 jaxpr;

  2. 在每個 shmap 主體中,無論是立即執行還是階段性輸出,都像當前「用於 out_specs 的複製檢查」機制一樣。前者可能最終會更容易,因為我們只需要處理 jaxpr 的情況,而且只需要線性原語。但我們將首先嘗試後者,因此此處的實作是對現有複製檢查邏輯的嚴格修訂/擴展。

附錄:定義和激勵具有未映射輸入和輸出的映射#

為了具體起見,我們將主要關注 shmap,儘管這些相同的想法適用於例如 pmap,甚至可能是 xmap

in_specs 的對應條目沒有提及該網格軸的名稱時,參數/輸入沿網格軸是未映射的 (unmapped)。從邏輯上講,這意味著沿該網格軸的每個函數實例都獲得相同的參數值。對於呼叫者來說,每個運算元都根據運算元映射到的網格軸進行切片,而對於運算元未映射到的網格軸則不進行切片。

out_specs 的對應條目沒有提及該網格軸的名稱時,輸出沿網格軸是未映射的 (unmapped)。從邏輯上講,這意味著沿該網格軸的每個函數實例都必須傳回相同的值。對於呼叫者來說,shmap 的每個結果都是通過串連每個函數實例的傳回值形成的,而輸出映射到這些函數實例,而對於輸出未映射到的網格軸,僅使用該值的一個副本。

有關未映射輸入和輸出的範例,請參閱 shmap JEP。為了比較,在 vmap 中,未映射的輸入/輸出通過使用 in_axes / out_axesNone(而不是 int)來指示。

以下是我們喜歡 shmap 的未映射輸入和輸出的原因

  • pjit 相同的表達能力。 pjit 可以做的任何事情,shmap 的應急方案也應該能夠做到。否則我們的應急方案就會有所欠缺!如果我們在 shmap 中沒有未映射的輸出,那麼我們就無法表達與 pjit 相同的批次並行損失函數計算。

  • 閉包輸入 (Closed-over input)。 閉包輸入 (Closed-over input) 本質上對應於未映射的輸入,並且…

  • 轉置下的閉包 (Closure under transposition)。 一旦我們有了未映射的輸入,自然就可以轉置為未映射的輸出。

因此,未映射的輸出既是規範的,又是有用的!