形狀多型#

當 JAX 在 JIT 模式下使用時,會追蹤函式、將其降階為 StableHLO,並針對輸入類型和形狀的每種組合進行編譯。在匯出函式並在另一個系統上還原序列化之後,我們不再有可用的 Python 原始碼,因此我們無法重新追蹤和重新降階它。 形狀多型 是 JAX 匯出的一項功能,允許某些匯出的函式用於整個系列的輸入形狀。 這些函式在匯出期間會被追蹤和降階一次,並且 Exported 物件包含能夠針對許多具體輸入形狀編譯和執行函式所需的資訊。 我們透過在匯出時指定包含維度變數(符號形狀)的形狀來做到這一點,如下例所示

>>> import jax
>>> from jax import export
>>> from jax import numpy as jnp
>>> def f(x):  # f: f32[a, b]
...   return jnp.concatenate([x, x], axis=1)

>>> # We construct symbolic dimension variables.
>>> a, b = export.symbolic_shape("a, b")

>>> # We can use the symbolic dimensions to construct shapes.
>>> x_shape = (a, b)
>>> x_shape
(a, b)

>>> # Then we export with symbolic shapes:
>>> exp: export.Exported = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct(x_shape, jnp.int32))
>>> exp.in_avals
(ShapedArray(int32[a,b]),)
>>> exp.out_avals
(ShapedArray(int32[a,2*b]),)

>>> # We can later call with concrete shapes (with a=3 and b=4), without re-tracing `f`.
>>> res = exp.call(np.ones((3, 4), dtype=np.int32))
>>> res.shape
(3, 8)

請注意,此類函式在每次調用時,仍會針對每個具體的輸入形狀按需重新編譯。 只有追蹤和降階會被儲存。

在上面的範例中,jax.export.symbolic_shape() 用於將符號形狀的字串表示形式解析為維度表達式物件(類型為 _DimExpr),這些物件可以取代整數常數來建構形狀。 維度表達式物件會多載大多數整數運算子,因此在大多數情況下,您可以像使用整數常數一樣使用它們。 有關更多詳細資訊,請參閱 使用維度變數進行計算

此外,我們提供了 jax.export.symbolic_args_specs(),可用於根據多型形狀規格建構 jax.ShapeDtypeStruct 物件的 pytree

>>> def f1(x, y): # x: f32[a, 1], y : f32[a, 4]
...  return x + y

>>> # Assuming you have some actual args with concrete shapes
>>> x = np.ones((3, 1), dtype=np.int32)
>>> y = np.ones((3, 4), dtype=np.int32)
>>> args_specs = export.symbolic_args_specs((x, y), "a, ...")
>>> exp = export.export(jax.jit(f1))(* args_specs)
>>> exp.in_avals
(ShapedArray(int32[a,1]), ShapedArray(int32[a,4]))

請注意,多型形狀規格 "a, ..." 如何包含預留位置 ...,以從引數 (x, y) 的具體形狀中填入。 預留位置 ... 代表 0 個或多個維度,而預留位置 _ 代表一個維度。 jax.export.symbolic_args_specs() 支援引數的 pytree,用於填入 dtype 和任何預留位置。 函式將建構引數規格的 pytree (jax.ShapeDtypeStruct),以符合傳遞給它的引數結構。 在一個規格應適用於多個引數的情況下,多型形狀規格可以是 pytree 前綴,如上述範例所示。 請參閱 選用參數如何與 pytree 引數匹配

形狀規格的一些範例

  • ("(b, _, _)", None) 可用於具有兩個引數的函式,第一個引數是 3D 陣列,其批次前導維度應為符號。 第一個引數的其他維度和第二個引數的形狀會根據實際引數進行特殊化。 請注意,如果第一個引數是 3D 陣列的 pytree,且都具有相同的前導維度,但可能具有不同的尾隨維度,則相同的規格也適用。 第二個引數的值 None 表示該引數不是符號的。 相當地,可以使用 ...

  • ("(batch, ...)", "(batch,)") 指定兩個引數具有匹配的前導維度,第一個引數的秩至少為 1,第二個引數的秩為 1。

形狀多型的正確性#

我們希望確信,對於任何適用的具體形狀,匯出的程式在編譯和執行時,會產生與原始 JAX 程式相同的結果。 更精確地說

對於任何 JAX 函式 f 和任何包含符號形狀的引數規格 arg_spec,以及任何形狀與 arg_spec 匹配的具體引數 arg

  • 如果 JAX 原生執行在具體引數上成功:res = f(arg)

  • 並且如果使用符號形狀匯出成功:exp = export.export(f)(arg_spec)

  • 那麼編譯和執行匯出將會成功,並產生相同的結果:res == exp.call(arg)

務必理解,f(arg) 可以自由地重新調用 JAX 追蹤機制,並且實際上它會針對每個不同的具體 arg 形狀執行此操作,而 exp.call(arg) 的執行無法再使用 JAX 追蹤(此執行可能發生在 f 的原始碼不可用的環境中)。

確保這種形式的正確性很困難,在最困難的情況下,匯出會失敗。 本章的其餘部分描述如何處理這些失敗。

使用維度變數進行計算#

JAX 會追蹤所有中間結果的形狀。 當這些形狀取決於維度變數時,JAX 會將它們計算為涉及維度變數的符號維度表達式。 維度變數代表大於或等於 1 的整數值。 符號表達式可以表示將算術運算子(add、sub、mul、floordiv、mod,包括 NumPy 變體 np.sumnp.prod 等)應用於維度表達式和整數intnp.int 或任何可透過 operator.index 轉換的物件)的結果。 這些符號維度隨後可用於 JAX 基本運算和 API 的形狀參數中,例如,在 jnp.reshapejnp.arange、切片索引等中。

例如,在以下程式碼中,為了展平 2D 陣列,計算 x.shape[0] * x.shape[1] 會將符號維度 4 * b 計算為新形狀

>>> f = lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],))
>>> arg_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.int32)
>>> exp = export.export(jax.jit(f))(arg_spec)
>>> exp.out_avals
(ShapedArray(int32[4*b]),)

可以將維度表達式明確地轉換為 JAX 陣列,使用 jnp.array(x.shape[0]) 甚至 jnp.array(x.shape)。 這些運算的結果可以用作常規 JAX 陣列,但不能再用作形狀中的維度,例如,在 reshape

>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))
>>> exp.call(jnp.arange(3, dtype=np.int32))
Array([3, 4, 5], dtype=int32)

>>> exp = export.export(jax.jit(lambda x: x.reshape(jnp.array(x.shape[0]) + 2)))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))  
Traceback (most recent call last):
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].

當符號維度用於與非整數(例如,floatnp.floatnp.ndarray 或 JAX 陣列)進行算術運算時,它會自動使用 jnp.array 轉換為 JAX 陣列。 例如,在下面的函式中,x.shape[0] 的所有出現都會隱式轉換為 jnp.array(x.shape[0]),因為它們涉及與非整數純量或 JAX 陣列的運算

>>> exp = export.export(jax.jit(
...     lambda x: (5. + x.shape[0],
...                x.shape[0] - np.arange(5, dtype=jnp.int32),
...                x + x.shape[0] + jnp.sin(x.shape[0]))))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), jnp.int32))
>>> exp.out_avals
(ShapedArray(float32[], weak_type=True),
 ShapedArray(int32[5]),
 ShapedArray(float32[b], weak_type=True))

>>> exp.call(jnp.ones((3,), jnp.int32))
 (Array(8., dtype=float32, weak_type=True),
  Array([ 3, 2, 1, 0, -1], dtype=int32),
  Array([4.14112, 4.14112, 4.14112], dtype=float32, weak_type=True))

另一個典型的範例是計算平均值時(觀察 x.shape[0] 如何自動轉換為 JAX 陣列)

>>> exp = export.export(jax.jit(
...     lambda x: jnp.sum(x, axis=0) / x.shape[0]))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b, c"), jnp.int32))
>>> exp.call(jnp.arange(12, dtype=jnp.int32).reshape((3, 4)))
Array([4., 5., 6., 7.], dtype=float32)

形狀多型存在時的錯誤#

大多數 JAX 程式碼都假設 JAX 陣列的形狀是整數元組,但使用形狀多型時,某些維度可能是符號表達式。 這可能會導致許多錯誤。 例如,我們可能會遇到常見的 JAX 形狀檢查錯誤

>>> v, = export.symbolic_shape("v,")
>>> export.export(jax.jit(lambda x, y: x + y))(
...     jax.ShapeDtypeStruct((v,), dtype=np.int32),
...     jax.ShapeDtypeStruct((4,), dtype=np.int32))
Traceback (most recent call last):
TypeError: add got incompatible shapes for broadcasting: (v,), (4,).

>>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))(
...     jax.ShapeDtypeStruct((v, 4), dtype=np.int32))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (v,).

我們可以透過指定引數具有形狀 (v, v) 來修正上述 matmul 範例。

符號維度的比較僅部分支援#

在 JAX 內部,有許多涉及形狀的相等和不等比較,例如,用於執行形狀檢查,甚至用於為某些基本運算選擇實作。 比較的支援方式如下

  • 相等比較受到支援,但有一個注意事項:如果兩個符號維度在維度變數的所有估值下都表示相同的值,則相等比較會評估為 True,例如,對於 b + b == 2*b;否則,相等比較會評估為 False。 有關此行為重要後果的討論,請參閱下方

  • 不等比較始終是相等比較的否定。

  • 不等性比較僅部分支援,方式與部分相等比較類似。 但是,在這種情況下,我們會考慮到維度變數的範圍是嚴格正整數。 例如,b >= 1b >= 02 * a + b >= 3True,而 b >= 2a >= ba - b >= 0 則不確定,並導致例外。

在比較運算無法解析為布林值的情況下,我們會引發 InconclusiveDimensionOperation。 例如,

import jax
>>> export.export(jax.jit(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'a + 1' >= 'b' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.

如果您確實遇到 InconclusiveDimensionOperation,您可以嘗試以下幾種策略

  • 如果您的程式碼使用內建的 maxmin,或 np.maxnp.min,那麼您可以將它們替換為 core.max_dimcore.min_dim,它們的效果是將不等性比較延遲到編譯時,屆時形狀會變得已知。

  • 嘗試使用 core.max_dimcore.min_dim 重寫條件式,例如,您可以使用 core.max_dim(d, 0) 而不是 d if d > 0 else 0

  • 嘗試重寫程式碼,使其較少依賴維度應為整數的事實,並依賴符號維度在大多數算術運算中以整數進行 duck-typing 的事實。 例如,使用 d + 5 而不是 int(d) + 5

  • 指定符號約束,如下所述。

使用者指定的符號約束#

預設情況下,JAX 假設所有維度變數的範圍都大於或等於 1 的值,並且它會嘗試從中推導出其他簡單的不等式,例如

  • a + 2 >= 3,

  • a * 2 >= 1,

  • a + b + c >= 3,

  • a // 4 >= 0a**2 >= 1 等等。

如果您變更符號形狀規格以新增維度大小的隱含約束,則可以避免某些不等性比較失敗。 例如,

  • 您可以使用維度 2*b 來約束其為偶數且大於或等於 2。

  • 您可以使用維度 b + 15 來約束其至少為 16。 例如,以下程式碼在沒有 + 15 部分的情況下會失敗,因為 JAX 會想要驗證切片大小最多與軸大小一樣大。

>>> _ = export.export(jax.jit(lambda x: x[0:16]))(
...    jax.ShapeDtypeStruct(export.symbolic_shape("b + 15"), dtype=np.int32))

此類隱含符號約束用於決定比較,並在編譯時檢查,如下下方所述。

您也可以指定明確的符號約束

>>> # Introduce dimension variable with constraints.
>>> a, b = export.symbolic_shape("a, b",
...                              constraints=("a >= b", "b >= 16"))
>>> _ = export.export(jax.jit(lambda x: x[:x.shape[1], :16]))(
...    jax.ShapeDtypeStruct((a, b), dtype=np.int32))

約束與隱含約束一起形成連詞。 您可以指定 >=<=== 約束。 目前,JAX 對於使用符號約束進行推理的支援有限

  • 您可以從變數大於或等於或小於或等於常數的形式的約束中獲得最大收益。 例如,從 a >= 16b >= 8 的約束中,我們可以推斷出 a + 2*b >= 32

  • 當約束涉及更複雜的表達式時,您會獲得有限的力量,例如,從 a >= b + 8 中,我們可以推斷出 a - b >= 8,但不能推斷出 a >= 9。 我們未來可能會在某種程度上改進這個領域。

  • 等式約束被視為重寫規則:每當遇到 == 左側的符號表達式時,它都會被重寫為右側的表達式。 例如,floordiv(a, b) == c 的工作原理是將所有出現的 floordiv(a, b) 替換為 c。 等式約束不得在左側的頂層包含加法或減法。 有效左側的範例包括 a * b4 * afloordiv(a + c, b)

>>> # Introduce dimension variable with equality constraints.
>>> a, b, c, d = export.symbolic_shape("a, b, c, d",
...                                    constraints=("a * b == c + d",))
>>> 2 * b * a
2*d + 2*c

>>> a * b * b
b*d + b*c

符號約束也有助於解決 JAX 推理機制中的限制。 例如,在下面的程式碼中,JAX 將嘗試證明切片大小 x.shape[0] % 3(即符號表達式 mod(b, 3))小於或等於軸大小(即 b)。 這對於 b 的所有嚴格正值都為真,但 JAX 的符號比較規則無法證明這一點。 因此,以下程式碼會引發錯誤

from jax import lax
>>> b, = export.symbolic_shape("b")
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((b,), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'b' >= 'mod(b, 3)' is inconclusive.
This error arises for comparison operations with shapes that
are non-constant, and the result of the operation cannot be represented as
a boolean value for all values of the symbolic dimensions involved.

這裡的一個選項是將程式碼限制為僅適用於 3 的倍數的軸大小(透過將形狀中的 b 替換為 3*b)。 這樣,JAX 將能夠將模數運算 mod(3*b, 3) 簡化為 0。 另一個選項是新增一個符號約束,其中包含 JAX 嘗試證明的確切不確定不等式

>>> b, = export.symbolic_shape("b",
...                            constraints=["b >= mod(b, 3)"])
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> _ = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((b,), dtype=np.int32))

與隱含約束一樣,顯式符號約束在編譯時使用與下方說明的相同機制進行檢查。

符號維度範圍#

符號約束儲存在 jax.export.SymbolicScope 物件中,該物件是針對每次呼叫 jax.export.symbolic_shapes() 時隱式建立的。 您必須小心不要混合使用使用不同範圍的符號表達式。 例如,以下程式碼將會失敗,因為 a1a2 使用不同的範圍(由 jax.export.symbolic_shape() 的不同調用建立)

>>> a1, = export.symbolic_shape("a,")
>>> a2, = export.symbolic_shape("a,", constraints=("a >= 8",))

>>> a1 + a2  
Traceback (most recent call last):
ValueError: Invalid mixing of symbolic scopes for linear combination.
Expected  scope 4776451856 created at <doctest shape_poly.md[31]>:1:6 (<module>)
and found for 'a' (unknown) scope 4776979920 created at <doctest shape_poly.md[32]>:1:6 (<module>) with constraints:
  a >= 8

源自單次呼叫 jax.export.symbolic_shape() 的符號表達式共享一個範圍,並且可以在算術運算中混合使用。 結果也將共享相同的範圍。

您可以重複使用範圍

>>> a, = export.symbolic_shape("a,", constraints=("a >= 8",))
>>> b, = export.symbolic_shape("b,", scope=a.scope)  # Reuse the scope of `a`

>>> a + b  # Allowed
b + a

您也可以明確地建立範圍

>>> my_scope = export.SymbolicScope()
>>> c, = export.symbolic_shape("c", scope=my_scope)
>>> d, = export.symbolic_shape("d", scope=my_scope)
>>> c + d  # Allowed
d + c

JAX 追蹤使用部分以形狀為鍵的快取,並且如果符號形狀使用不同的範圍,即使它們的列印結果相同,也會被視為不同。

等式比較的注意事項#

對於 b + 1 == bb == 0 (在這種情況下,可以確定所有維度變數的值的維度都不同)以及對於 b == 1 和對於 a == b,相等性比較會傳回 False。這是不可靠的,我們應該引發 core.InconclusiveDimensionOperation,因為在某些估值下,結果應該是 True,而在其他估值下,結果應該是 False。我們選擇使相等性成為完全的,因此允許不可靠性,因為否則當雜湊維度表達式或包含它們的物件(形狀、core.AbstractValuecore.Jaxpr)時,我們可能會遇到偽造的錯誤(spurious errors)。除了雜湊錯誤之外,相等性的部分語意會導致以下表達式的錯誤 b == a or b == bb in [a, b],即使我們更改比較的順序,也可以避免錯誤。

即使使用這種相等性處理方式,if x.shape[0] != 1: raise NiceErrorMessage 形式的程式碼仍然是可靠的,但 if x.shape[0] != 1: return 1 形式的程式碼是不可靠的。

維度變數必須可以從輸入形狀中解出#

目前,在調用導出的物件時,傳遞維度變數值的唯一方法是透過陣列引數的形狀間接傳遞。例如,b 的值可以在呼叫點從 f32[b] 類型的第一個引數的形狀推斷出來。這在大多數使用情況下都運作良好,並且它反映了 JIT 函數的呼叫慣例。

有時您可能想要導出由整數值參數化的函數,該整數值決定程式中的某些形狀。例如,我們可能想要導出下面定義的函數 my_top_k,並由 k 的值參數化,這決定了結果的形狀。以下嘗試會導致錯誤,因為維度變數 k 無法從輸入 x: i32[4, 10] 的形狀導出

>>> def my_top_k(k, x):  # x: i32[4, 10], k <= 10
...   return lax.top_k(x, k)[0]  # : i32[4, 3]
>>> x = np.arange(40, dtype=np.int32).reshape((4, 10))

>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`.
>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x)
>>> exp_static_k.in_avals[0]
ShapedArray(int32[4,10])

>>> exp_static_k.out_avals[0]
ShapedArray(int32[4,3])

>>> # When calling the exported function we pass only the non-static arguments
>>> exp_static_k.call(x)
Array([[ 9,  8,  7],
       [19, 18, 17],
       [29, 28, 27],
       [39, 38, 37]], dtype=int32)

>>> # Now attempt to export with symbolic `k` so that we choose `k` after export.
>>> k, = export.symbolic_shape("k", constraints=["k <= 10"])
>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x)  
Traceback (most recent call last):
UnexpectedDimVar: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments

未來,除了透過輸入形狀隱含地傳遞維度變數的值之外,我們可能會新增額外的機制來傳遞。同時,上述用例的權宜之計是用形狀為 (0, k) 的陣列替換函數參數 k,以便可以從陣列的輸入形狀導出 k。第一個維度為 0 是為了確保整個陣列為空,並且在我們調用導出的函數時不會產生效能損失。

>>> def my_top_k_with_dimensions(dimensions, x):  # dimensions: i32[0, k], x: i32[4, 10]
...   return my_top_k(dimensions.shape[1], x)
>>> exp = export.export(jax.jit(my_top_k_with_dimensions))(
...     jax.ShapeDtypeStruct((0, k), dtype=np.int32),
...     x)
>>> exp.in_avals
(ShapedArray(int32[0,k]), ShapedArray(int32[4,10]))

>>> exp.out_avals[0]
ShapedArray(int32[4,k])

>>> # When we invoke `exp` we must construct and pass an array of shape (0, k)
>>> exp.call(np.zeros((0, 3), dtype=np.int32), x)
Array([[ 9,  8,  7],
       [19, 18, 17],
       [29, 28, 27],
       [39, 38, 37]], dtype=int32)

另一個您可能會遇到錯誤的情況是,當某些維度變數確實出現在輸入形狀中,但在 JAX 目前無法解出的非線性表達式中時

>>> a, = export.symbolic_shape("a")
>>> export.export(jax.jit(lambda x: x.shape[0]))(
...    jax.ShapeDtypeStruct((a * a,), dtype=np.int32))  
Traceback (most recent call last):
ValueError: Cannot solve for values of dimension variables {'a'}.
We can only solve linear uni-variate constraints.
Using the following polymorphic shapes specifications: args[0].shape = (a^2,).
Unprocessed specifications: 'a^2' for dimension size args[0].shape[0].

形狀斷言錯誤#

JAX 假設維度變數的範圍嚴格為正整數,並且在為具體輸入形狀編譯程式碼時會檢查此假設。

例如,給定符號輸入形狀 (b, b, 2*d),當使用實際引數 arg 調用時,JAX 將產生程式碼來檢查以下斷言

  • arg.shape[0] >= 1

  • arg.shape[1] == arg.shape[0]

  • arg.shape[2] % 2 == 0

  • arg.shape[2] // 2 >= 1

例如,以下是我們在以形狀為 (3, 3, 5) 的引數調用導出時得到的錯誤

>>> def f(x):  # x: f32[b, b, 2*d]
...   return x
>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b, b, 2*d"), dtype=np.int32))   
>>> exp.call(np.ones((3, 3, 5), dtype=np.int32))  
Traceback (most recent call last):
ValueError: Input shapes do not match the polymorphic shapes specification.
Division had remainder 1 when computing the value of 'd'.
Using the following polymorphic shapes specifications:
  args[0].shape = (b, b, 2*d).
Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), .
Please see https://jax.dev.org.tw/en/latest/export/shape_poly.html#shape-assertion-errors for more details.

這些錯誤發生在編譯之前的預處理步驟中。

偵錯#

首先,請參閱偵錯文件。此外,您可以偵錯形狀精煉,它在編譯時針對具有維度變數或多平台支援的模組調用。

如果在形狀精煉期間發生錯誤,您可以設定 JAX_DUMP_IR_TO 環境變數,以查看形狀精煉之前 HLO 模組的傾印(命名為 ..._before_refine_polymorphic_shapes.mlir)。這個模組應該已經具有靜態輸入形狀。

若要啟用所有形狀精煉階段的記錄,您可以設定環境變數 TF_CPP_VMODULE=refine_polymorphic_shapes=3 在 OSS 中(在 Google 內部,您傳遞 --vmodule=refine_polymorphic_shapes=3

# Log from python
JAX_DUMP_IR_TO=/tmp/export.dumps/ TF_CPP_VMODULE=refine_polymorphic_shapes=3 python tests/shape_poly_test.py ShapePolyTest.test_simple_unary -v=3