提前降低和編譯#

JAX 提供多種轉換,例如 jax.jitjax.pmap,傳回一個在加速器或 CPU 上編譯和執行的函式。正如 JIT 縮寫所示,所有編譯都即時發生以供執行。

某些情況需要提前 (AOT) 編譯。當您想要在執行時間之前完全編譯,或者您想要控制編譯過程的不同部分何時發生時,JAX 為您提供了一些選項。

首先,讓我們回顧編譯的階段。假設 fjax.jit() 輸出的函式/可呼叫物件,例如 f = jax.jit(F) 用於某些輸入可呼叫物件 F。當使用引數呼叫它時,例如 f(x, y),其中 xy 是陣列,JAX 依序執行以下操作

  1. 階段性輸出原始 Python 可呼叫物件 F 的專門版本到內部表示。此專門化反映了將 F 限制為從引數 xy 的屬性(通常是其形狀和元素型別)推斷出的輸入型別。

  2. 降低此專門化、階段性輸出的計算到 XLA 編譯器的輸入語言 StableHLO。

  3. 編譯降低後的 HLO 程式以產生目標裝置(CPU、GPU 或 TPU)的優化可執行檔。

  4. 執行使用陣列 xy 作為引數的編譯後可執行檔。

JAX 的 AOT API 讓您可以直接控制步驟 #2、#3 和 #4(但不是 #1),以及沿途的其他一些功能。一個範例

>>> import jax

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> lowered = jax.jit(f).lower(x, y)

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<2> : tensor<i32>
    %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
    %1 = stablehlo.add %0, %arg1 : tensor<i32>
    return %1 : tensor<i32>
  }
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
Array(10, dtype=int32, weak_type=True)

請注意,降低後的物件只能在降低它們的同一個進程中使用。對於匯出用例,請參閱匯出和序列化 API。

有關降低和編譯函式提供的更多功能詳細資訊,請參閱 jax.stages 文件。

jit 的所有可選引數(例如 static_argnums)在相應的降低、編譯和執行中都會受到尊重。

在上面的範例中,我們可以將 lower 的引數替換為任何具有 shapedtype 屬性的物件

>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
Array(10, dtype=int32)

更一般而言,lower 只需要其引數在結構上提供 JAX 為了專門化和降低而必須知道的資訊。對於像上面這樣的典型陣列引數,這表示 shapedtype 欄位。相反,對於靜態引數,JAX 需要實際的陣列值(更多資訊請參閱下方)。

使用與其降低不相容的引數調用 AOT 編譯的函式會引發錯誤

>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]

>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]

相關地,AOT 編譯的函式無法被 JAX 的即時轉換(例如 jax.jitjax.grad()jax.vmap())轉換。

使用靜態引數降低#

使用靜態引數降低強調了傳遞給 jax.jit 的選項、傳遞給 lower 的引數,以及調用產生的編譯函式所需的引數之間的交互作用。繼續上面的範例

>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)

>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<14> : tensor<i32>
    %0 = stablehlo.add %c, %arg0 : tensor<i32>
    return %0 : tensor<i32>
  }
}

>>> lowered_with_x.compile()(5)
Array(19, dtype=int32, weak_type=True)

lower 的結果直接序列化以在不同進程中使用是不安全的。有關此目的的其他 API,請參閱匯出和序列化

請注意,此處的 lower 像往常一樣接受兩個引數,但後續的編譯函式僅接受剩餘的非靜態第二個引數。靜態第一個引數(值 7)在降低時被視為常數,並建置到降低後的計算中,在那裡它可能會與其他常數合併。在這種情況下,它與 2 的乘法被簡化,產生常數 14。

儘管上面的 lower 的第二個引數可以替換為空心形狀/dtype 結構,但靜態第一個引數必須是具體值。否則,降低會出錯

>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)  
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
Array(25, dtype=int32)

AOT 編譯的函式無法轉換#

編譯後的函式專門用於一組特定的引數「型別」,例如具有特定形狀和元素型別的陣列,在我們的執行範例中。從 JAX 的內部觀點來看,諸如 jax.vmap() 之類的轉換會以使為編譯型別簽名失效的方式更改函式的型別簽名。作為一項政策,JAX 只是不允許編譯函式參與轉換。範例

>>> def g(x):
...   assert x.shape == (3, 2)
...   return x @ jnp.ones(2)

>>> def make_z(*shape):
...   return jnp.arange(np.prod(shape)).reshape(shape)

>>> z, zs = make_z(3, 2), make_z(4, 3, 2)

>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).lower(z).compile()

>>> jax.vmap(g_jit)(zs)
Array([[ 1.,  5.,  9.],
       [13., 17., 21.],
       [25., 29., 33.],
       [37., 41., 45.]], dtype=float32)

>>> jax.vmap(g_aot)(zs)  
Traceback (most recent call last):
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>

g_aot 涉及自動微分時(例如 jax.grad()),也會引發類似的錯誤。為了保持一致性,也禁止通過 jax.jit 進行轉換,即使 jit 並未有意義地修改其引數的型別簽名。

除錯資訊和分析(若可用)#

除了主要的 AOT 功能(分離和明確的降低、編譯和執行)之外,JAX 的各種 AOT 階段還提供了一些額外功能,以幫助進行除錯和收集編譯器回饋。

例如,正如上面的初始範例所示,降低後的函式通常提供文字表示。編譯後的函式也執行相同的操作,並且還提供來自編譯器的成本和記憶體分析。所有這些都是通過 jax.stages.Loweredjax.stages.Compiled 物件上的方法提供的(例如,上面的 lowered.as_text()compiled.cost_analysis())。您可以使用 debug_info 參數傳遞給 lowered.as_text() 來獲得更多除錯資訊,例如原始碼位置。

這些方法旨在作為人工檢查和除錯的輔助工具,而不是作為可靠的可程式化 API。它們的可用性和輸出因編譯器、平台和執行階段而異。這造成了兩個重要的注意事項

  1. 如果某些功能在 JAX 當前後端上不可用,則其方法會傳回一些微不足道的東西(和類似 False 的東西)。例如,如果 JAX 底層的編譯器不提供成本分析,則 compiled.cost_analysis() 將為 None

  2. 如果某些功能可用,則對於相應方法提供的內容仍然存在非常有限的保證。不需要傳回值在 JAX 配置、後端/平台、版本甚至方法調用之間保持類型、結構或值的一致性。JAX 無法保證 compiled.cost_analysis() 在一天的輸出在第二天將保持不變。

如有疑問,請參閱 jax.stages 的套件 API 文件。

檢查階段性計算#

本筆記頂部列表中的階段 #1 提到了專門化和階段性輸出,在降低之前。JAX 內部對專門用於其引數型別的函式的概念並不總是記憶體中的具體化資料結構。若要顯式建構 JAX 在內部 Jaxpr 中間語言中對函式專門化的視圖,請參閱 jax.make_jaxpr()