使用 Pallas 撰寫 TPU 核心#

本頁重點介紹在嘗試於 Google TPU 上執行 Pallas 核心時重要的細節。首先,TPU 後端仍處於實驗階段,僅接受 JAX NumPy 的子集。此外,為 TPU 撰寫高效能程式碼可能需要仔細思考硬體的原生功能。雖然許多對硬體來說不自然的模式將被接受,但它們最終可能需要軟體模擬,並可能減慢計算速度。

警告

此功能仍應被視為實驗性功能,因為工作仍在進行中(尤其是在改進錯誤訊息方面)。

注意

雖然此處描述的所有功能都是實驗性的,但我們仍然非常重視保持其正確性。因此,在嘗試撰寫 TPU 核心時,可能會看到「未實作」錯誤並不罕見。但是,如果核心被編譯器接受,它<強調>必須傳回預期的結果。

如果您看到非預期的輸出,請將它們與傳遞 interpret=Truepallas_call 執行的核心進行比較。如果結果發散,請提交錯誤報告

什麼是 TPU?#

A TPUv4 board

TPU 是 Google 開發的硬體加速器。您可以將 TPU 視為 GPU,但專門針對機器學習工作負載而設計。因此,它們的架構差異很大。但是,我們相信 Pallas 可以讓您輕鬆開始撰寫 TPU 核心,即使您沒有完全理解底層硬體。話雖如此,充分了解硬體肯定會讓您更容易撰寫高效能核心。

簡而言之,TPU 和 GPU 之間的主要區別在於 TPU 是具有非常寬的向量暫存器(有點像 CPU!)的循序機器。同時,它們允許軟體在背景中排程某些操作,使其相對於主要指令流非同步執行。這包括諸如 HBM 記憶體存取(無法直接發出,而是必須由 DMA 子單元預取到記憶體階層的較低層級)、矩陣乘法(由 MXU 單元支援)或矩陣轉置和排列(由 XLU 單元支援)之類的事情。

如果您有興趣詳細了解 TPU 架構,我們建議您閱讀多年來發表的一系列論文。雖然其中許多論文討論了特定的 TPU 世代,但其中描述的許多想法也適用於後來的世代。

值得注意的屬性和限制#

BlockSpec 和網格迭代#

BlockSpec(請參閱BlockSpec,又名如何將輸入分塊)通常在 Pallas 中的行為與預期相同 — 核心主體的每次調用都可以存取輸入的切片,並且旨在初始化輸出的切片。

注意

並非所有區塊形狀都受支援。在 TPU 上,僅支援秩至少為 1 的區塊

。此外,區塊形狀的最後兩個維度必須分別可被 8 和 128 整除,或者等於整體陣列的各自維度。

Pallas TPU 核心的一個有趣方面是它們處理記憶體空間的方式:雖然 pallas_call 的輸入通常駐留在 HBM(主 TPU 記憶體)中,但傳遞到核心主體的參考會指向記憶體階層較低層級(VMEM 或 SMEM)中的緩衝區。這使核心主體能夠以非常高的速度寫入和讀取它們,而所有與 HBM 的通訊(具有非常高的延遲)都由編譯器處理並與計算重疊。

更重要的是,與 GPU 相比,TPU 實際上是高度循序的機器。因此,網格通常不是平行處理的,而是循序地按字典順序處理(但請參閱多核心 TPU 配置章節以了解例外情況)。這解鎖了一些有趣的功能

  • 當兩個(按字典順序)連續的網格索引使用輸入的相同切片時,第二個迭代的 HBM 傳輸將被跳過,因為資料已可用。

  • 核心主體的多次調用可以寫入輸出的相同切片,而不會有任何競爭條件的風險。但是,我們確實要求所有寫入特定切片的調用都是連續的。

輸出上的「連續」限制通常意味著網格維度的某些前綴始終會改變調用需要存取的輸出切片,而輸出視窗對於剩餘的後綴保持不變。

例如,當為矩陣乘法實作 Pallas TPU 核心時,通常會使用 3 維網格:前兩個維度將對應於沿左運算元的第一個軸和第二個運算元的第二個軸進行切片。第三個也是<強調>最後一個網格軸將平鋪歸約維度。對應於歸約維度的網格軸必須是最後一個,因為輸出視窗沿此軸不會變化。然後,輸出參考可以用作部分結果的累加器。

注意

對於如此低階的記憶體階層(16MB+),VMEM 相當大,因此可以使用大的視窗大小。而且,通常情況下,視窗大小越大,最終的硬體利用率就越高。但是,可以指定一個視窗大小,該大小(連同保存溢出向量暫存器所需的空間)超過 VMEM 的大小。在這種情況下,您可能會看到一個低階編譯器錯誤訊息,抱怨記憶體不足錯誤。

陣列佈局#

陣列的維度順序在 Pallas 中是有意義的。在 JAX 程式中,jax.jit 內的中間陣列的順序通常對效能沒有影響,因為編譯器可以自由地重新排列它們。但是,由於 Pallas 旨在公開較低階的功能,因此維度順序可能會對產生的程式碼品質產生重大影響。

TPU 在 2D 向量暫存器上執行大部分計算,對於 32 位元值,其大小通常為 8x128(截至 TPU v6)。當從 VMEM 將向量值載入暫存器時(例如 x = x_ref[...]),陣列的最後兩個維度將被平鋪到暫存器中。Pallas 將只考慮將中間陣列的最後兩個維度映射到 8x128 向量暫存器維度(分別為子通道和通道)。

以下是如何使用 6 個 8x128 瓦片平鋪 12x320 陣列的圖形範例

../../_images/vector_layout_example.svg

平鋪佈局對核心撰寫者有幾個重要的影響

  • 陣列的最後兩個軸的處理方式與其他軸不同。例如,當涉及最後兩個軸時,歸約、重塑和轉置通常更昂貴。某些涉及最後兩個維度的重塑不受支援,將導致編譯器錯誤,但對於其他維度來說是「免費」的,並且在編譯時執行。

  • 雖然有時不可避免,但在最後兩個軸中具有單例維度通常是浪費的,因為它們將佔用整個瓦片維度中的 1 個元素。消耗過多的暫存器也可能導致暫存器溢出到 VMEM 中,從而降低核心效能。

  • 與上述觀點相關,所有向量計算都填充到瓦片大小。將兩個 1x1 陣列相加的成本與將兩個 8x128 陣列相加的成本相同,而將兩個 8x128x1x1 陣列相加的成本將是將兩個 8x128 陣列相加的 1024 倍,因為 8x128x1x1 陣列將被填充到 8x128x8x128。

多核心 TPU 配置#

在較新的 TPU 世代中,晶片上的兩個核心通常被抽象為單個裝置。為了利用多個核心,Pallas 必須打破循序網格執行保證,並且需要將其中一個網格軸平行化到核心上。這是一個選擇加入的程序。為了允許這樣做,pallas_call 需要一個名為 dimension_semantics 的額外參數

該參數是一個列表,其條目數與網格中的軸數相同。只有 parallel 維度可以跨核心進行分割。根據經驗法則,維度是平行的,除非輸出視窗不變化。因此,dimension_semantics 始終是若干個 parallel 軸,後跟若干個 arbitrary 軸。

雖然在 2 核心 TPU 裝置上分割核心通常會帶來 2 倍的加速,但實際上可能顯著更小。如果主體的不同實例具有高度不同的成本,則尤其如此。如果所有昂貴的步驟都映射到一個核心,但所有廉價的步驟都分配給另一個核心,則第二個核心將處於閒置狀態,直到第一個核心完成其任務。

Pallas TPU 通常傾向於分割大小為 TPU 核心數量的倍數的軸,並且傾向於分割前導網格軸。

將運算元放置在 SMEM 中#

TPU 上的大部分計算將在向量單元上進行。儘管如此,在許多情況下,執行許多純量運算很有用,例如,執行控制流程。為此,TPU 配備了單獨的純量單元和連接到它的單獨的純量記憶體 (SMEM)。根據經驗法則,用於執行控制流程決策的任何資料都應放置在 SMEM 中。

SMEM 是一種低延遲記憶體,支援隨機存取,但每次指令只允許您讀取和寫入 32 位元值(與 VMEM 交易的 4KBi 粒度相比非常小,但由於缺少對齊要求而更靈活!)。

當實作不以規則模式存取輸入瓦片的核心時,例如在撰寫區塊稀疏核心時,純量記憶體也非常有用。在 Pallas 中,這可以透過將 grid 引數替換為 pallas_call,並使用具有非零 num_scalar_prefetch 引數的 PrefetchScalarGridSpecgrid_spec 來實現。如果 num_scalar_prefetchn,則 pallas_call 的前 n 個引數將放置在 SMEM 中。不應為這些引數指定 BlockSpec。但是,所有後續引數的 BlockSpec 將不僅接收網格索引,還接收前導運算元的 SMEM 參考。

請參閱 純量預取和區塊稀疏計算 以取得有關使用此功能的範例。

支援的資料型別#

目前 Pallas TPU 支援以下資料型別

  • jnp.float32

  • jnp.bfloat16

  • jnp.int* (所有精度,除了 jnp.int4)

  • jnp.uint* (所有精度)

  • jnp.bool_

運算放置#

所有純量(即 0D)陣列都將儲存在純量暫存器中,並且對它們的操作將在純量核心上執行。所有其他操作(即使是單元素,但 1D+ 陣列)都將在向量核心上執行。

支援的操作#

矩陣乘法#

矩陣乘法始終以 float32 格式產生結果。如果您的輸入不是 float32,我們建議使用 lax.dot 並將 preferred_element_type 設定為 jnp.float32

當使用 lax.dot_general 時,可以將矩陣乘法運算元的最後兩個維度的轉置融合到運算中,這可以提高整體核心效能。

精確度控制#

Pallas TPU 降低器知道 jax.default_matmul_precision。為了獲得最佳效能(和最低精度),請使用 bfloat16。如果您關心數值準確性,您可能需要將精度設定為 float32

警告

即使您將 32 位元運算元傳遞到矩陣乘法中,除非請求 float32 精度,否則它們也會四捨五入為 bfloat16

轉置#

如果值至少有 4 個維度,則除了最後兩個軸之外的所有軸的任意轉置都是免費的。否則,僅實作最後兩個軸的轉置。請注意,最後兩個維度的某些轉置可以融合到矩陣乘法中。

存取記憶體#

可以讀取或更新參考的任意切片,但須遵守實作限制。目前,對於 32 位元寬的輸入沒有限制,但對於較窄的型別僅支援某些切片模式。在最後兩個維度中與 8 和 128 的倍數對齊並且長度是 8 和 128 的倍數的讀取和寫入始終受支援。

對向量記憶體的讀取和寫入通常發生在 (8, 128) 形狀的瓦片上。因此,當讀取或寫入至少具有兩個維度的參考時,當記憶體存取的基本偏移量具有可被平鋪整除的索引,並且讀取區域的大小是瓦片大小的倍數時,可以獲得最佳效能。

元素級操作#

支援許多元素級操作。值得注意的是,硬體通常僅支援使用 32 位元型別的元素級計算。當載入使用較低精度型別的運算元時,它們通常應在套用元素級操作之前提升為 32 位元型別。

值得注意的是,它們的成本可能<強調>顯著不同。因此,我們概述了三類支援的操作:廉價 (🟢)、中等 (🌕) 和昂貴 (🔴)。

操作

成本

jnp.add, +

🟢

jnp.sub, -

🟢

jnp.mul, *

🟢

/, //, %

🌕

jnp.max, jnp.min

🟢

jnp.where (select)

🟢

jnp.abs

🟢

|, ^, &, ~

🟢

<<, >>

🟢

比較 (==, …)

🟢

型別轉換 (.astype)

🟢

jnp.exp

🌕

jnp.tanh

🌕

jnp.pow

🌕

jnp.sin

🔴

jnp.cos

🔴

許多 JAX 函式是根據其他 JAX 基本運算實作的,因此此列表可能不全面。例如,jax.nn.relu 是根據比較和 jnp.where 實作的,也適用於 Pallas 核心。

陣列建構子#

支援所有常數陣列建構子 (jnp.ones, jnp.zeros, jnp.full)。

歸約#

支援 sum, max, min (用於浮點值) 歸約,以及用於布林值的 anyall。不支援整數歸約。

最後一個陣列維度上的歸約通常最慢。倒數第二個維度上的歸約速度更快,但仍然比前導維度上的歸約慢。

廣播#

廣播的效能特性與歸約非常相似。始終支援沿除最後兩個尾隨維度之外的所有維度進行廣播,並且是免費的。沿倒數第二個維度進行廣播速度較慢,而沿最後一個維度進行廣播速度最慢。

重塑#

與往常一樣,支援除最後兩個維度之外的所有維度中的重塑,並且是免費的。

當重塑可以修改陣列的最後兩個維度時,僅支援兩種情況:(1) 某些前導維度被展平到倒數第二個維度上,或者 (2) 它新增了剛被歸約移除的維度。

隨機數生成#

Pallas 支援來自 jax.random 模組中最常用的函數,例如 uniformnormalbernoulli。金鑰應為 threefry2x32 金鑰,這是 JAX 中的預設設定。金鑰可以直接傳遞到核心中,或在核心內部產生。

控制流程#

TPU 後端目前對控制流程的支援有限。目前支援的函數為 condfori_loopfor_loop。然而,迴圈基本單元目前在編譯期間會完全展開,因此請盡量保持迴圈迭代次數在合理的小範圍內。

過度使用控制流程可能會導致低階程式碼生成方面顯著的效能衰退,建議盡可能地將大量運算密集的作業塞進單一基本區塊中。