JAX 內部機制:jaxpr 語言#

Jaxpr 是 JAX 程式的內部中介表示法 (IR)。它們是顯式型別、函數式、一階,且為代數正規形式 (ANF)。

概念上,可以將 JAX 轉換(例如 jax.jit()jax.grad())視為首先追蹤專門化要轉換的 Python 函數,使其成為一個小型且行為良好的中介形式,然後使用特定於轉換的解譯規則進行解譯。

JAX 能夠在如此小的軟體套件中包含如此強大的功能的原因之一,是它從熟悉且靈活的程式設計介面(Python 與 NumPy)開始,並使用實際的 Python 解譯器來完成大部分繁重的工作,將計算的本質提煉成一種簡單的靜態型別表示式語言,並具有有限的高階功能。

該語言就是 jaxpr 語言。jaxpr 項語法如下所示

jaxpr ::=
  { lambda <binder> , ... .
    let <eqn>
        ...
    in ( <atom> , ... ) }

binder ::= <var>:<array_type>
var ::= a | b | c | ...
atom ::= <var> | <literal>
literal ::= <int32> | <int64> | <float32> | <float64>

eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...

並非所有 Python 程式都可以這樣處理,但事實證明,許多科學計算和機器學習程式都可以。

在繼續之前,請記住,並非所有 JAX 轉換都會如上所述實際具體化 jaxpr。它們中的一些,例如微分或批次處理,將在追蹤期間逐步應用轉換。儘管如此,如果想要了解 JAX 的內部運作方式,或利用 JAX 追蹤的結果,了解 jaxpr 是很有用的。

jax.core.ClosedJaxpr#

一個 jaxpr 實例代表一個函數,它具有一個或多個型別參數(輸入變數)和一個或多個型別結果。結果僅取決於輸入變數;沒有從封閉作用域捕獲的自由變數。輸入和輸出具有型別,在 JAX 中以抽象值表示。

程式碼中對於 jaxpr 有兩種相關的表示形式,jax.core.Jaxprjax.core.ClosedJaxprjax.core.ClosedJaxpr 代表部分應用的 jax.core.Jaxpr,當您使用 jax.make_jaxpr() 檢查 jaxpr 時,您會獲得它。它具有以下欄位

  • jaxpr:是一個 jax.core.Jaxpr,代表函數的實際計算內容(如下所述)。

  • consts 是一個常數列表。

ClosedJaxpr 最有趣的部分是實際的執行內容,它以 jax.core.Jaxpr 的形式表示,並使用以下語法列印

jaxpr ::= { lambda Var* ; Var+.
            let Eqn*
            in  [Expr+] }

其中

  • jaxpr 的參數顯示為兩個變數列表,以 ; 分隔

    • 第一組變數是為了代表已經提升出來的常數而引入的變數。這些變數稱為 constvars,在 jax.core.ClosedJaxpr 中,consts 欄位保存對應的值。

    • 第二個變數列表,稱為 invars,對應於追蹤的 Python 函數的輸入。

  • Eqn* 是一個方程式列表,定義了指向中介表示式的中介變數。每個方程式定義一個或多個變數,作為在某些原子表示式上應用 primitive 的結果。每個方程式僅使用輸入變數和先前方程式定義的中介變數。

  • Expr+:是 jaxpr 的輸出原子表示式(字面值或變數)列表。

方程式列印如下

Eqn  ::= let Var+ = Primitive [ Param* ] Expr+

其中

  • Var+ 是一個或多個中介變數,將被定義為 primitive 調用的輸出(某些 primitive 可以返回多個值)。

  • Expr+ 是一個或多個原子表示式,每個可以是變數或字面常數。一個特殊的變數 unitvar 或字面值 unit,列印為 *,表示計算的其餘部分不需要的值,並且已被省略。也就是說,unit 只是佔位符。

  • Param* 是 primitive 的零個或多個具名參數,以方括號列印。每個參數都顯示為 Name = Value

大多數 jaxpr primitive 都是一階的(它們只接受一個或多個 Expr 作為引數)

Primitive := add | sub | sin | mul | ...

最常見的 jaxpr primitive 記錄在 jax.lax 模組中。

例如,以下是為下面的函數 func1 生成的 jaxpr

from jax import make_jaxpr
import jax.numpy as jnp

def func1(first, second):
   temp = first + jnp.sin(second) * 3.
   return jnp.sum(temp)

print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

這裡沒有 constvars,ab 是輸入變數,它們分別對應於 firstsecond 函數參數。純量字面值 3.0 保留為內聯。除了運算元 e 之外,reduce_sum primitive 還具有具名參數 axesinput_shape

請注意,即使調用 JAX 的程式執行會建構 jaxpr,Python 層級的控制流程和 Python 層級的函數也會正常執行。這表示僅僅因為 Python 程式包含函數和控制流程,產生的 jaxpr 不必包含控制流程或高階功能。

例如,當追蹤函數 func3 時,JAX 將內聯調用 inner 和條件式 if second.shape[0] > 4,並將產生與之前相同的 jaxpr

def func2(inner, first, second):
  temp = first + inner(second) * 3.
  return jnp.sum(temp)

def inner(second):
  if second.shape[0] > 4:
    return jnp.sin(second)
  else:
    assert False

def func3(first, second):
  return func2(inner, first, second)

print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

處理 pytrees#

在 jaxpr 中沒有元組型別;相反,primitive 接受多個輸入並產生多個輸出。當處理具有結構化輸入或輸出的函數時,JAX 將展平它們,並且在 jaxpr 中,它們將顯示為輸入和輸出的列表。有關更多詳細資訊,請參閱 Pytrees 教學。

例如,以下程式碼產生的 jaxpr 與您之前看到的相同(具有兩個輸入變數,每個變數對應於輸入元組的一個元素)

def func4(arg):  # The `arg` is a pair.
  temp = arg[0] + jnp.sin(arg[1]) * 3.
  return jnp.sum(temp)

print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

常數變數 (vars)#

jaxpr 中的某些值是常數,因為它們的值不取決於 jaxpr 的引數。當這些值是純量時,它們直接在 jaxpr 方程式中表示。非純量陣列常數則提升到頂層 jaxpr,在那裡它們對應於常數變數(“constvars”)。這些 constvars 與其他 jaxpr 參數(“invars”)的不同僅在於簿記慣例。

高階 JAX primitives#

Jaxpr 包含幾個高階 JAX primitive。它們更複雜,因為它們包含子 jaxpr。

cond primitive(條件式)#

JAX 會追蹤正常的 Python 條件式。若要捕獲用於動態執行的條件表示式,必須使用 jax.lax.switch()jax.lax.cond() 建構函式,它們具有以下簽名

lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B

lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B

這兩者都將在內部綁定一個名為 cond 的 primitive。jaxpr 中的 cond primitive 反映了 lax.switch() 的更通用簽名:它接受一個整數,表示要執行的分支的索引(鉗制到有效的索引範圍內)。

例如

from jax import lax

def one_of_three(index, arg):
  return lax.switch(index, [lambda x: x + 1.,
                            lambda x: x - 2.,
                            lambda x: x + 3.],
                    arg)

print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    d:i32[] = clamp 0 c 2
    e:f32[] = cond[
      branches=(
        { lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
        { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
        { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
      )
    ] d b
  in (e,) }

cond primitive 有許多參數

  • branches 是對應於分支函數的 jaxpr。在本例中,這些函數各自接受一個輸入變數,對應於 x

  • linear 是一個布林元組,由自動微分機制在內部使用,以編碼哪些輸入參數在條件式中線性使用。

上面的 cond primitive 實例接受兩個運算元。第一個 (d) 是分支索引,然後 b 是運算元 (arg),將傳遞給 branches 中由分支索引選擇的任何 jaxpr。

另一個範例,使用 jax.lax.cond()

from jax import lax

def func7(arg):
  return lax.cond(arg >= 0.,
                  lambda xtrue: xtrue + 3.,
                  lambda xfalse: xfalse - 3.,
                  arg)

print(make_jaxpr(func7)(5.))
{ lambda ; a:f32[]. let
    b:bool[] = ge a 0.0
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:f32[] = cond[
      branches=(
        { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) }
        { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) }
      )
    ] c a
  in (d,) }

在本例中,布林述詞被轉換為整數索引(0 或 1),並且 branches 是對應於 false 和 true 分支函數的 jaxpr,依序排列。同樣,每個函數接受一個輸入變數,分別對應於 xfalsextrue

以下範例顯示了一種更複雜的情況,當分支函數的輸入是元組時,並且 false 分支函數包含一個常數 jnp.ones(1),該常數被提升為 constvar

def func8(arg1, arg2):  # Where `arg2` is a pair.
  return lax.cond(arg1 >= 0.,
                  lambda xtrue: xtrue[0],
                  lambda xfalse: jnp.array([1]) + xfalse[1],
                  arg2)

print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
    e:bool[] = ge b 0.0
    f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    g:f32[1] = cond[
      branches=(
        { lambda ; h:i32[1] i:f32[1] j:f32[]. let
            k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h
            l:f32[1] = add k j
          in (l,) }
        { lambda ; m_:i32[1] n:f32[1] o:f32[]. let  in (n,) }
      )
    ] f a c d
  in (g,) }

while primitive#

就像條件式一樣,Python 迴圈在追蹤期間是內聯的。如果想要捕獲用於動態執行的迴圈,則必須使用幾個特殊操作之一:jax.lax.while_loop()(一個 primitive)和 jax.lax.fori_loop()(一個生成 while_loop primitive 的輔助函數)

lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C

在上面的簽名中,C 代表迴圈「carry」值的型別。例如,以下是一個 fori_loop 範例

import numpy as np

def func10(arg, n):
  ones = jnp.ones(arg.shape)  # A constant.
  return lax.fori_loop(0, n,
                       lambda i, carry: carry + ones * 3. + arg,
                       arg + ones)

print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a:f32[16] b:i32[]. let
    c:f32[16] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(16,)
      sharding=None
    ] 1.0
    d:f32[16] = add a c
    _:i32[] _:i32[] e:f32[16] = while[
      body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
          k:i32[] = add h 1
          l:f32[16] = mul f 3.0
          m:f32[16] = add j l
          n:f32[16] = add m g
        in (k, i, n) }
      body_nconsts=2
      cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let
          r:bool[] = lt o p
        in (r,) }
      cond_nconsts=0
    ] c a 0 b d
  in (e,) }

while primitive 接受 5 個引數:c a 0 b d,如下所示

  • 用於 cond_jaxpr 的 0 個常數(因為 cond_nconsts 為 0)

  • 用於 body_jaxpr 的 2 個常數(ca

  • 用於 carry 初始值的 3 個參數

scan primitive#

JAX 支援一種特殊形式的迴圈,用於迭代陣列的元素(具有靜態已知的形狀)。存在固定迭代次數的事實使得這種形式的迴圈易於反向微分。此類迴圈是使用 jax.lax.scan() 函數建構的

lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])

這是以 Haskell 型別簽名 的形式寫成的:Cscan carry 的型別,A 是輸入陣列的元素型別,B 是輸出陣列的元素型別。

對於範例,請考慮下面的函數 func11

def func11(arr, extra):
  ones = jnp.ones(arr.shape)  #  A constant
  def body(carry, aelems):
    # carry: running dot-product of the two arrays
    # aelems: a pair with corresponding elements from the two arrays
    ae1, ae2 = aelems
    return (carry + ae1 * ae2 + extra, carry)
  return lax.scan(body, 0., (arr, ones))

print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a:f32[16] b:f32[]. let
    c:f32[16] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(16,)
      sharding=None
    ] 1.0
    d:f32[] e:f32[16] = scan[
      _split_transpose=False
      jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let
          j:f32[] = mul h i
          k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
          l:f32[] = add k j
          m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
          n:f32[] = add l m
        in (n, g) }
      length=16
      linear=(False, False, False, False)
      num_carry=1
      num_consts=1
      reverse=False
      unroll=1
    ] b 0.0 a c
  in (d, e) }

linear 參數描述了對於每個輸入變數,是否保證它們在 body 中線性使用。一旦 scan 完成線性化,更多引數將會是線性的。

scan primitive 接受 4 個引數:b 0.0 a c,其中

  • 一個是 body 的自由變數

  • 一個是 carry 的初始值

  • 接下來的 2 個是 scan 操作的陣列

(p)jit primitive#

call primitive 源自 JIT 編譯,它封裝了一個子 jaxpr 以及指定後端和應在其上執行計算的裝置的參數。例如

from jax import jit

def func12(arg):
  @jit
  def inner(x):
    return x + arg * jnp.ones(1)  # Include a constant in the inner function.
  return arg + inner(arg - 2.)

print(make_jaxpr(func12)(1.))
{ lambda ; a:f32[]. let
    b:f32[] = sub a 2.0
    c:f32[1] = pjit[
      name=inner
      jaxpr={ lambda ; d:f32[] e:f32[]. let
          f:f32[1] = broadcast_in_dim[
            broadcast_dimensions=()
            shape=(1,)
            sharding=None
          ] 1.0
          g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
          h:f32[1] = mul g f
          i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
          j:f32[1] = add i h
        in (j,) }
    ] a b
    k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
    l:f32[1] = add k c
  in (l,) }