分析裝置記憶體#

注意

2023 年 5 月更新:我們建議使用Tensorboard 分析進行裝置記憶體分析。在進行分析後,開啟 Tensorboard 分析器的 memory_viewer 標籤頁,以取得更詳細且易於理解的裝置記憶體用量。

JAX 裝置記憶體分析器可讓我們探索 JAX 程式如何以及為何使用 GPU 或 TPU 記憶體。例如,它可以被用於

  • 找出在特定時間點哪些陣列和可執行檔在 GPU 記憶體中,或

  • 追蹤記憶體洩漏。

安裝#

JAX 裝置記憶體分析器會發出可以使用 pprof (google/pprof) 解譯的輸出。首先安裝 pprof,方法是依照其安裝指示。在撰寫本文時,安裝 pprof 需要先安裝 1.16+ 版本的 GoGraphviz,然後執行

go install github.com/google/pprof@latest

這會將 pprof 安裝為 $GOPATH/bin/pprof,其中 GOPATH 預設為 ~/go

注意

來自 google/pprofpprof 版本與作為 gperftools 套件一部分發行的同名舊工具不同。gperftools 版本的 pprof 無法與 JAX 搭配使用。

了解 JAX 程式如何使用 GPU 或 TPU 記憶體#

裝置記憶體分析器的常見用途是找出 JAX 程式為何使用大量 GPU 或 TPU 記憶體,例如,如果嘗試偵錯記憶體不足的問題。

若要將裝置記憶體分析擷取到磁碟,請使用jax.profiler.save_device_memory_profile()。例如,考慮以下 Python 程式

import jax
import jax.numpy as jnp
import jax.profiler

def func1(x):
  return jnp.tile(x, 10) * 0.5

def func2(x):
  y = func1(x)
  return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)

z.block_until_ready()

jax.profiler.save_device_memory_profile("memory.prof")

如果我們先執行上述程式,然後執行

pprof --web memory.prof

pprof 會開啟一個網頁瀏覽器,其中包含裝置記憶體分析的下列視覺化,格式為呼叫圖

Device memory profiling example

呼叫圖是在進行每個即時緩衝區配置時,Python 堆疊的視覺化。例如,在這個特定案例中,視覺化顯示 func2 及其被呼叫者負責配置 76.30MB,其中 38.15MB 是在從 func1func2 的呼叫中配置的。如需有關如何解譯呼叫圖視覺化的詳細資訊,請參閱 pprof 文件

使用 jax.jit() 編譯的函式對裝置記憶體分析器而言是不透明的。也就是說,在 jit 編譯的函式內部配置的任何記憶體都將歸因於整個函式。

在範例中,呼叫 block_until_ready() 是為了確保 func2 在收集裝置記憶體分析之前完成。請參閱非同步分派以取得更多詳細資訊。

偵錯記憶體洩漏#

我們也可以使用 JAX 裝置記憶體分析器來追蹤記憶體洩漏,方法是使用 pprof 來視覺化在不同時間點取得的兩個裝置記憶體分析之間的記憶體用量變化。例如,考慮以下程式,該程式會將 JAX 陣列累積到不斷成長的 Python 列表中。

import jax
import jax.numpy as jnp
import jax.profiler

def afunction():
  return jax.random.normal(jax.random.key(77), (1000000,))

z = afunction()

def anotherfunc():
  arrays = []
  for i in range(1, 10):
    x = jax.random.normal(jax.random.key(42), (i, 10000))
    arrays.append(x)
    x.block_until_ready()
    jax.profiler.save_device_memory_profile(f"memory{i}.prof")

anotherfunc()

如果我們只視覺化執行結束時的裝置記憶體分析 (memory9.prof),則迴圈中每次迭代 anotherfunc 會累積更多裝置記憶體配置可能不明顯

pprof --web memory9.prof

Device memory profile at end of execution

afunction 內的大型但固定的配置主導了分析,但不會隨著時間增長。

透過使用 pprof--diff_base 功能來視覺化跨迴圈迭代的記憶體用量變化,我們可以識別出程式的記憶體用量為何隨著時間增加

pprof --web --diff_base memory1.prof memory9.prof

Device memory profile at end of execution

視覺化顯示記憶體增長可歸因於 anotherfunc 內對 normal 的呼叫。