錯誤#

本頁列出使用 JAX 時可能遇到的一些錯誤,以及如何修正這些錯誤的代表性範例。

class jax.errors.ConcretizationTypeError(tracer, context='')#

當 JAX Tracer 物件在需要具體值的環境中使用時,就會發生此錯誤(如需更多關於 Tracer 的資訊,請參閱不同種類的 JAX 值)。在某些情況下,可以透過將有問題的值標記為靜態來輕鬆修正;在其他情況下,可能表示您的程式正在執行 JAX 的 JIT 編譯模型不直接支援的操作。

範例

預期靜態值時卻使用了追蹤值

此錯誤的一個常見原因是,在需要靜態值的地方使用了追蹤值。例如:

>>> from functools import partial
>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, axis):
...   return x.min(axis)
>>> func(jnp.arange(4), 0)  
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: axis argument to jnp.min().

通常可以透過將有問題的引數標記為靜態來修正此問題

>>> @partial(jit, static_argnums=1)
... def func(x, axis):
...   return x.min(axis)

>>> func(jnp.arange(4), 0)
Array(0, dtype=int32)
形狀取決於追蹤值

當 JIT 編譯的計算中的形狀取決於追蹤數量內的值時,也可能發生這種錯誤。例如:

>>> @jit
... def func(x):
...     return jnp.where(x < 0)

>>> func(jnp.arange(4))  
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
The error arose in jnp.nonzero.

這是一個與 JAX 的 JIT 編譯模型不相容的操作範例,該模型要求陣列大小在編譯時就已知。在這裡,傳回陣列的大小取決於 x 的內容,而此類程式碼無法進行 JIT 編譯。

在許多情況下,可以透過修改函數中使用的邏輯來解決此問題;例如,以下是具有類似問題的程式碼:

>>> @jit
... def func(x):
...     indices = jnp.where(x > 1)
...     return x[indices].sum()

>>> func(jnp.arange(4))  
Traceback (most recent call last):
    ...
ConcretizationTypeError: Abstract tracer value encountered where concrete
value is expected: The error arose in jnp.nonzero.

以下是如何以避免建立動態大小索引陣列的方式,表達相同的操作:

>>> @jit
... def func(x):
...   return jnp.where(x > 1, x, 0).sum()

>>> func(jnp.arange(4))
Array(5, dtype=int32)

若要更深入瞭解與追蹤器與常規值以及具體值與抽象值相關的細微差異,您可能需要閱讀不同種類的 JAX 值

參數:
  • tracer (core.Tracer)

  • context (str)

class jax.errors.KeyReuseError(message)#

當 PRNG 金鑰以不安全的方式重複使用時,就會發生此錯誤。金鑰重複使用僅在 jax_debug_key_reuse 設定為 True 時才會檢查。

以下是一個簡單的程式碼範例,會導致此類錯誤:

>>> with jax.debug_key_reuse(True):  
...   key = jax.random.key(0)
...   value = jax.random.uniform(key)
...   new_value = jax.random.uniform(key)
...
---------------------------------------------------------------------------
KeyReuseError                             Traceback (most recent call last)
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0

這種金鑰重複使用是有問題的,因為 JAX PRNG 是無狀態的,而且金鑰必須手動分割;如需更多資訊,請參閱虛擬隨機數教學課程

參數:

message (str)

jax.errors.JaxRuntimeError#

XlaRuntimeError 的別名

class jax.errors.NonConcreteBooleanIndexError(tracer)#

當程式嘗試在追蹤的索引操作中使用非具體的布林索引時,就會發生此錯誤。在 JIT 編譯下,JAX 陣列必須具有靜態形狀(即編譯時已知的形狀),因此必須謹慎使用布林遮罩。透過布林遮罩實作的某些邏輯在 jax.jit() 函數中根本不可能實現;在其他情況下,邏輯可以以 JIT 相容的方式重新表達,通常使用 where() 的三引數版本。

以下是一些可能發生此錯誤的範例。

透過布林遮罩建構陣列

當嘗試在 JIT 環境中透過布林遮罩建立陣列時,最常發生這種情況。例如:

>>> import jax
>>> import jax.numpy as jnp

>>> @jax.jit
... def positive_values(x):
...   return x[x > 0]

>>> positive_values(jnp.arange(-5, 5))  
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])

此函數嘗試僅傳回輸入陣列中的正值;除非將 x 標記為靜態,否則無法在編譯時確定此傳回陣列的大小,因此此類操作無法在 JIT 編譯下執行。

可重新表達的布林邏輯

雖然不直接支援建立動態大小的陣列,但在許多情況下,可以根據 JIT 相容的操作重新表達計算的邏輯。例如,以下是另一個由於相同原因而在 JIT 下失敗的函數:

>>> @jax.jit
... def sum_of_positive(x):
...   return x[x > 0].sum()

>>> sum_of_positive(jnp.arange(-5, 5))  
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])

然而,在這種情況下,有問題的陣列僅是一個中間值,我們可以改用 JIT 相容的三引數版本 jax.numpy.where() 來表達相同的邏輯:

>>> @jax.jit
... def sum_of_positive(x):
...   return jnp.where(x > 0, x, 0).sum()

>>> sum_of_positive(jnp.arange(-5, 5))
Array(10, dtype=int32)

這種將布林遮罩替換為三引數 where() 的模式是解決此類問題的常見方法。

布林索引到 JAX 陣列中

另一個經常發生此錯誤的情況是使用布林索引,例如使用 .at[...].set(...)。以下是一個簡單的範例:

>>> @jax.jit
... def manual_clip(x):
...   return x.at[x < 0].set(0)

>>> manual_clip(jnp.arange(-2, 2))  
Traceback (most recent call last):
    ...
NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])

此函數嘗試將小於零的值設定為純量填滿值。如上所述,可以透過使用 where() 重新表達邏輯來解決此問題:

>>> @jax.jit
... def manual_clip(x):
...   return jnp.where(x < 0, 0, x)

>>> manual_clip(jnp.arange(-2, 2))
Array([0, 0, 0, 1], dtype=int32)
參數:

tracer (core.Tracer)

class jax.errors.TracerArrayConversionError(tracer)#

當程式嘗試將 JAX Tracer 物件轉換為標準 NumPy 陣列時,就會發生此錯誤(如需更多關於 Tracer 的資訊,請參閱不同種類的 JAX 值)。它通常在以下幾種情況下發生。

在 JAX 轉換中使用非 JAX 函數

如果您嘗試在 JAX 轉換(jit()grad()jax.vmap() 等)內使用非 JAX 程式庫(如 numpyscipy)時,可能會發生此錯誤。例如:

>>> from jax import jit
>>> import numpy as np

>>> @jit
... def func(x):
...   return np.sin(x)

>>> func(np.arange(4))  
Traceback (most recent call last):
    ...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on traced array with shape int32[4]

在這種情況下,您可以透過使用 jax.numpy.sin() 取代 numpy.sin() 來修正問題:

>>> import jax.numpy as jnp
>>> @jit
... def func(x):
...   return jnp.sin(x)

>>> func(jnp.arange(4))
Array([0.        , 0.84147096, 0.9092974 , 0.14112   ], dtype=float32)

另請參閱外部回呼,瞭解從轉換後的 JAX 程式碼回呼到主機端計算的選項。

使用追蹤器索引 NumPy 陣列

如果此錯誤發生在涉及陣列索引的程式碼行上,則可能是被索引的陣列 x 是標準 numpy.ndarray,而索引 idx 是追蹤的 JAX 陣列。例如:

>>> x = np.arange(10)

>>> @jit
... def func(i):
...   return x[i]

>>> func(0)  
Traceback (most recent call last):
    ...
TracerArrayConversionError: The numpy.ndarray conversion method
__array__() was called on traced array with shape int32[0]

根據上下文,您可以透過將 NumPy 陣列轉換為 JAX 陣列來修正此問題:

>>> @jit
... def func(i):
...   return jnp.asarray(x)[i]

>>> func(0)
Array(0, dtype=int32)

或透過將索引宣告為靜態引數:

>>> from functools import partial
>>> @partial(jit, static_argnums=(0,))
... def func(i):
...   return x[i]

>>> func(0)
Array(0, dtype=int32)

若要更深入瞭解與追蹤器與常規值以及具體值與抽象值相關的細微差異,您可能需要閱讀不同種類的 JAX 值

參數:

tracer (core.Tracer)

class jax.errors.TracerBoolConversionError(tracer)#

當 JAX 中的追蹤值在預期布林值的環境中使用時,就會發生此錯誤(如需更多關於 Tracer 的資訊,請參閱不同種類的 JAX 值)。

布林轉換可能是顯式的(例如 bool(x))或隱式的,透過控制流程的使用(例如 if x > 0while x)、Python 布林運算子的使用(例如 z = x and yz = x or yz = not x)或使用它們的函數(例如 z = max(x, y)z = min(x, y) 等)。

在某些情況下,可以透過將追蹤值標記為靜態來輕鬆修正此問題;在其他情況下,可能表示您的程式正在執行 JAX 的 JIT 編譯模型不直接支援的操作。

範例

在控制流程中使用的追蹤值

經常發生這種情況的一種情況是,追蹤值用於 Python 控制流程中。例如:

>>> from jax import jit
>>> import jax.numpy as jnp
>>> @jit
... def func(x, y):
...   return x if x.sum() < y.sum() else y

>>> func(jnp.ones(4), jnp.zeros(4))  
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]

我們可以將輸入 xy 都標記為靜態,但這會失去在此處使用 jax.jit() 的目的。另一種選擇是以三項式 jax.numpy.where() 重新表達 if 陳述式:

>>> @jit
... def func(x, y):
...   return jnp.where(x.sum() < y.sum(), x, y)

>>> func(jnp.ones(4), jnp.zeros(4))
Array([0., 0., 0., 0.], dtype=float32)

對於包括迴圈在內更複雜的控制流程,請參閱控制流程運算子

追蹤值上的控制流程

另一個導致此錯誤的常見原因是,您不小心追蹤了布林旗標。例如:

>>> @jit
... def func(x, normalize=True):
...   if normalize:
...     return x / x.sum()
...   return x

>>> func(jnp.arange(5), True)  
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

在這裡,由於旗標 normalize 被追蹤,因此無法在 Python 控制流程中使用。在這種情況下,最佳解決方案可能是將此值標記為靜態:

>>> from functools import partial
>>> @partial(jit, static_argnames=['normalize'])
... def func(x, normalize=True):
...   if normalize:
...     return x / x.sum()
...   return x

>>> func(jnp.arange(5), True)
Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float32)

如需更多關於 static_argnums 的資訊,請參閱 jax.jit() 的文件。

使用非 JAX 感知函數

導致此錯誤的另一個常見原因是,在 JAX 程式碼中使用非 JAX 感知函數。例如:

>>> @jit
... def func(x):
...   return min(x, 0)
>>> func(2)  
Traceback (most recent call last):
    ...
TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...

在這種情況下,發生錯誤的原因是 Python 的內建 min 函數與 JAX 轉換不相容。可以透過將其替換為 jnp.minimum 來修正此問題:

>>> @jit
... def func(x):
...   return jnp.minimum(x, 0)
>>> print(func(2))
0

若要更深入瞭解與追蹤器與常規值以及具體值與抽象值相關的細微差異,您可能需要閱讀不同種類的 JAX 值

參數:

tracer (core.Tracer)

class jax.errors.TracerIntegerConversionError(tracer)#

當 JAX Tracer 物件在預期 Python 整數的環境中使用時,可能會發生此錯誤(如需更多關於 Tracer 的資訊,請參閱不同種類的 JAX 值)。它通常在以下幾種情況下發生。

傳遞追蹤器以取代整數

如果您嘗試將追蹤值傳遞給需要靜態整數引數的函數,可能會發生此錯誤;例如:

>>> from jax import jit
>>> import numpy as np

>>> @jit
... def func(x, axis):
...   return np.split(x, 2, axis)

>>> func(np.arange(4), 0)  
Traceback (most recent call last):
    ...
TracerIntegerConversionError: The __index__() method was called on
traced array with shape int32[0]

當發生這種情況時,解決方案通常是將有問題的引數標記為靜態:

>>> from functools import partial
>>> @partial(jit, static_argnums=1)
... def func(x, axis):
...   return np.split(x, 2, axis)

>>> func(np.arange(10), 0)
[Array([0, 1, 2, 3, 4], dtype=int32),
 Array([5, 6, 7, 8, 9], dtype=int32)]

另一種選擇是將轉換套用至封閉包裝,封閉包裝封裝了要保護的引數,可以手動執行如下,或使用 functools.partial()

>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4))
[Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]

請注意,每次調用時都會建立新的封閉包裝,這會破壞編譯快取機制,這就是為什麼靜態引數編號是首選的原因。

使用 Tracer 索引列表

如果您嘗試使用追蹤數量索引 Python 列表,可能會發生此錯誤。例如:

>>> import jax.numpy as jnp
>>> from jax import jit

>>> L = [1, 2, 3]

>>> @jit
... def func(i):
...   return L[i]

>>> func(0)  
Traceback (most recent call last):
    ...
TracerIntegerConversionError: The __index__() method was called on
traced array with shape int32[0]

根據上下文,您通常可以透過將列表轉換為 JAX 陣列來修正此問題:

>>> @jit
... def func(i):
...   return jnp.array(L)[i]

>>> func(0)
Array(1, dtype=int32)

或透過將索引宣告為靜態引數:

>>> from functools import partial
>>> @partial(jit, static_argnums=0)
... def func(i):
...   return L[i]

>>> func(0)
Array(1, dtype=int32, weak_type=True)

若要更深入瞭解與追蹤器與常規值以及具體值與抽象值相關的細微差異,您可能需要閱讀不同種類的 JAX 值

參數:

tracer (core.Tracer)

class jax.errors.UnexpectedTracerError(msg)#

當您使用已從函數洩漏出來的 JAX 值時,就會發生此錯誤。洩漏值是什麼意思?如果您對函數 f 使用 JAX 轉換,而該函數在 f 外部的某些範圍中儲存了對中間值的參照,則該值會被視為已洩漏。洩漏值是一種副作用。(如需更多關於避免副作用的資訊,請參閱純函數

當您稍後在另一個操作中使用洩漏的值時,JAX 會偵測到洩漏,此時它會引發 UnexpectedTracerError。若要修正此問題,請避免副作用:如果函數計算了外部範圍中需要的值,請從轉換後的函數中明確傳回該值。

具體來說,Tracer 是 JAX 在轉換期間函數的中間值的內部表示,例如在 jit()pmap()vmap() 等中。在轉換外部遇到 Tracer 表示洩漏。

洩漏值的生命週期

考慮以下已轉換函數的範例,該函數將值洩漏到外部範圍:

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit                   # 1
... def side_effecting(x):
...   y = x + 1            # 3
...   outs.append(y)       # 4

>>> x = 1
>>> side_effecting(x)      # 2
>>> outs[0] + 1            # 5  
Traceback (most recent call last):
    ...
UnexpectedTracerError: Encountered an unexpected tracer.

在此範例中,我們將追蹤值從內部轉換範圍洩漏到外部範圍。當使用洩漏的值時,我們會收到 UnexpectedTracerError,而不是在值洩漏時。

此範例也示範了洩漏值的生命週期:

  1. 函數已轉換(在此案例中,由 jit() 轉換)

  2. 已調用轉換後的函數(啟動函數的抽象追蹤,並將 x 轉換為 Tracer

  3. 已建立中間值 y,稍後將洩漏(追蹤函數的中間值也是 Tracer

  4. 值已洩漏(附加到外部範圍中的列表,透過側通道逸出函數)

  5. 已使用洩漏的值,並引發 UnexpectedTracerError。

UnexpectedTracerError 訊息嘗試透過包含關於每個階段的資訊來指向程式碼中的這些位置。分別是:

  1. 已轉換函數的名稱 (side_effecting) 以及哪個轉換啟動了追蹤 jit()

  2. 已重建洩漏的 Tracer 建立位置的堆疊追蹤,其中包括調用已轉換函數的位置。( Tracer 建立時,最後 5 個堆疊框架為...)。

  3. 從重建的堆疊追蹤中,建立洩漏的 Tracer 的程式碼行。

  4. 錯誤訊息中未包含洩漏位置,因為很難精確指出!JAX 只能告訴您洩漏的值看起來像什麼(它具有什麼形狀以及在何處建立),以及它洩漏到哪個邊界之外(轉換的名稱和已轉換函數的名稱)。

  5. 目前錯誤的堆疊追蹤指向使用值的位置。

可以透過從轉換後的函數中傳回值來修正錯誤:

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit
... def not_side_effecting(x):
...   y = x+1
...   return y

>>> x = 1
>>> y = not_side_effecting(x)
>>> outs.append(y)
>>> outs[0] + 1  # all good! no longer a leaked value.
Array(3, dtype=int32, weak_type=True)
洩漏檢查器

如上面第 2 點和第 3 點所述,JAX 顯示了指向洩漏值建立位置的重建堆疊追蹤。這是因為 JAX 僅在使用洩漏的值時才會引發錯誤,而不是在值洩漏時。這不是引發此錯誤最有用的位置,因為您需要知道 Tracer 洩漏的位置才能修正錯誤。

為了更容易追蹤此位置,您可以使用洩漏檢查器。啟用洩漏檢查器後,一旦 Tracer 洩漏,就會引發錯誤。(更準確地說,它會在洩漏 Tracer 的轉換函數傳回時引發錯誤)

若要啟用洩漏檢查器,您可以使用 JAX_CHECK_TRACER_LEAKS 環境變數或 with jax.checking_leaks() 上下文管理器。

注意

請注意,此工具是實驗性的,可能會報告誤報。它透過停用某些 JAX 快取來運作,因此會對效能產生負面影響,應僅在除錯時使用。

使用範例

>>> from jax import jit
>>> import jax.numpy as jnp

>>> outs = []
>>> @jit
... def side_effecting(x):
...   y = x+1
...   outs.append(y)

>>> x = 1
>>> with jax.checking_leaks():
...   y = side_effecting(x)  
Traceback (most recent call last):
    ...
Exception: Leaked Trace
參數:

msg (str)