🔪 JAX - The Sharp Bits 🔪#

Open in Colab Open in Kaggle

當您在義大利鄉間漫步時,人們會毫不猶豫地告訴您 JAX 具有 “una anima di pura programmazione funzionale”

JAX 是一種用於表達組合數值程式轉換的語言。JAX 也能夠為 CPU 或加速器 (GPU/TPU) 編譯數值程式。JAX 非常適合許多數值和科學程式,但前提是它們必須以某些約束條件撰寫,我們將在下方說明。

import numpy as np
from jax import jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

🔪 純函數#

JAX 轉換和編譯旨在僅適用於功能純粹的 Python 函數:所有輸入資料都透過函數參數傳遞,所有結果都透過函數結果輸出。如果使用相同的輸入調用,純函數將始終返回相同的結果。

以下是一些非功能純粹函數的範例,JAX 對這些函數的行為與 Python 直譯器不同。請注意,這些行為並非 JAX 系統保證;使用 JAX 的正確方法是僅將其用於功能純粹的 Python 函數。

def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect
  return x

# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]
g = 0.
def impure_uses_globals(x):
  return x + g

# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call:  4.0
Second call:  5.0
Third call, different type:  [14.]
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x

# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # Saved global has an internal JAX value
First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>

即使 Python 函數實際上在內部使用有狀態物件,只要它不讀取或寫入外部狀態,它仍然可以是功能純粹的

def pure_uses_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(5.))
50.0

不建議在任何您想要 jit 的 JAX 函數或任何控制流程基本運算中使用迭代器。原因是迭代器是一個 python 物件,它會引入狀態來檢索下一個元素。因此,它與 JAX 的函數式程式設計模型不相容。在下面的程式碼中,有一些嘗試將迭代器與 JAX 一起使用的不正確範例。它們大多數會返回錯誤,但有些會給出意想不到的結果。

import jax.numpy as jnp
from jax import make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
45
0

🔪 原位更新#

在 Numpy 中,您習慣這樣做

numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]

但是,如果我們嘗試原位更新 JAX 裝置陣列,我們會收到錯誤! (☉_☉)

%xmode Minimal
Exception reporting mode: Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.dev.org.tw/en/latest/_autosummary/jax.numpy.ndarray.at.html

允許變數原位變更會使程式分析和轉換變得困難。JAX 要求程式是純函數。

相反地,JAX 提供使用 .at 屬性於 JAX 陣列功能性陣列更新。

️⚠️ 在 jit 編譯的程式碼和 lax.while_looplax.fori_loop 內部,切片的大小不能是引數的函數,而只能是引數形狀的函數 – 切片起始索引沒有此限制。請參閱下方的控制流程章節,以取得有關此限制的更多資訊。

陣列更新:x.at[idx].set(y)#

例如,上面的更新可以寫成

updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
updated array:
 [[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]

與 NumPy 版本不同,JAX 的陣列更新函數以異地 (out-of-place) 方式運作。也就是說,更新後的陣列會作為新陣列返回,而原始陣列不會被更新修改。

print("original array unchanged:\n", jax_array)
original array unchanged:
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]

但是,在 jit 編譯的程式碼中,如果 x.at[idx].set(y)輸入值 x 未被重複使用,編譯器會最佳化陣列更新以原位發生。

使用其他運算的陣列更新#

索引陣列更新不限於僅覆寫值。例如,我們可以執行索引加法如下

print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]

有關索引陣列更新的更多詳細資訊,請參閱 .at 屬性的文件

🔪 越界索引#

在 Numpy 中,您習慣於在超出陣列邊界索引時拋出錯誤,就像這樣

np.arange(10)[11]
IndexError: index 11 is out of bounds for axis 0 with size 10

但是,從加速器上運行的程式碼引發錯誤可能很困難或不可能。因此,JAX 必須為越界索引選擇一些非錯誤行為 (類似於無效浮點運算如何導致 NaN)。當索引操作是陣列索引更新 (例如 index_add 或類似 scatter 的基本運算) 時,將會跳過越界索引的更新;當操作是陣列索引檢索 (例如 NumPy 索引或類似 gather 的基本運算) 時,由於必須返回某些內容,因此索引會被鉗制在陣列的邊界內。例如,陣列的最後一個值將從此索引操作返回

jnp.arange(10)[11]
Array(9, dtype=int32)

如果您想要更精細地控制越界索引的行為,您可以使用 ndarray.at 的可選參數;例如

jnp.arange(10.0).at[11].get()
Array(9., dtype=float32)
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Array(nan, dtype=float32)

請注意,由於索引檢索的這種行為,jnp.nanargminjnp.nanargmax 等函數對於由 NaN 組成的切片會返回 -1,而 Numpy 會拋出錯誤。

另請注意,由於上述兩種行為不是彼此的反向操作,因此反向模式自動微分 (將索引更新轉換為索引檢索,反之亦然) 將無法保留越界索引的語意。因此,將 JAX 中的越界索引視為 未定義行為 的一種情況可能是個好主意。

🔪 非陣列輸入:NumPy vs. JAX#

NumPy 通常很樂意接受 Python 列表或元組作為其 API 函數的輸入

np.sum([1, 2, 3])
np.int64(6)

JAX 偏離了這一點,通常會返回有用的錯誤

jnp.sum([1, 2, 3])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

這是一個經過深思熟慮的設計選擇,因為將列表或元組傳遞給追蹤函數可能會導致靜默的效能下降,否則可能難以檢測到。

例如,考慮以下允許列表輸入的寬鬆版本 jnp.sum

def permissive_sum(x):
  return jnp.sum(jnp.array(x))

x = list(range(10))
permissive_sum(x)
Array(45, dtype=int32)

輸出符合我們的預期,但這隱藏了底層潛在的效能問題。在 JAX 的追蹤和 JIT 編譯模型中,Python 列表或元組中的每個元素都被視為單獨的 JAX 變數,並單獨處理並推送到裝置。這可以在上面 permissive_sum 函數的 jaxpr 中看到

make_jaxpr(permissive_sum)(x)
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
    j:i32[]. let
    k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
    n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
    o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
    q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
    r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
    s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
    t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
    u:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] k
    v:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] l
    w:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] m
    x:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] n
    y:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] o
    z:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] p
    ba:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] q
    bb:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] r
    bc:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] s
    bd:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] t
    be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd
    bf:i32[] = reduce_sum[axes=(0,)] be
  in (bf,) }

列表的每個條目都作為單獨的輸入處理,導致追蹤和編譯開銷隨著列表的大小線性增長。為了防止此類意外發生,JAX 避免了將列表和元組隱式轉換為陣列。

如果您想將元組或列表傳遞給 JAX 函數,您可以先將其顯式轉換為陣列

jnp.sum(jnp.array(x))
Array(45, dtype=int32)

🔪 隨機數#

JAX 的偽隨機數生成在重要方面與 Numpy 的不同。如需快速入門指南,請參閱偽隨機數。有關更多詳細資訊,請參閱偽隨機數教學。

🔪 控制流程#

已移至使用 JIT 的控制流程與邏輯運算子

🔪 動態形狀#

jax.jitjax.vmapjax.grad 等轉換中使用的 JAX 程式碼要求所有輸出陣列和中間陣列都具有靜態形狀:也就是說,形狀不能依賴於其他陣列中的值。

例如,如果您要實作自己的 jnp.nansum 版本,您可能會從類似這樣的程式碼開始

def nansum(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  x_without_nans = x[mask]
  return x_without_nans.sum()

在 JIT 和其他轉換之外,這可以如預期般運作

x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))
10.0

如果您嘗試將 jax.jit 或其他轉換應用於此函數,它會出錯

jax.jit(nansum)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])

See https://jax.dev.org.tw/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

問題在於 x_without_nans 的大小取決於 x 中的值,這也表示其大小是動態的。通常在 JAX 中,可以透過其他方式解決對動態大小陣列的需求。例如,在這裡可以使用 jnp.where 的三引數形式將 NaN 值替換為零,從而在避免動態形狀的同時計算相同的結果

@jax.jit
def nansum_2(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  return jnp.where(mask, x, 0).sum()

print(nansum_2(x))
10.0

在其他出現動態形狀陣列的情況下,也可以使用類似的技巧。

🔪 NaNs#

除錯 NaNs#

如果您想追蹤 NaN 在您的函數或梯度中發生的位置,您可以透過以下方式開啟 NaN 檢查器

  • 設定 JAX_DEBUG_NANS=True 環境變數;

  • 在您的主檔案頂部附近加入 jax.config.update("jax_debug_nans", True)

  • jax.config.parse_flags_with_absl() 加入您的主檔案,然後使用類似 --jax_debug_nans=True 的命令列 flag 設定選項;

這將導致計算在產生 NaN 時立即出錯。開啟此選項會為 XLA 產生的每個浮點型別值新增 NaN 檢查。這表示值會被拉回主機並作為 ndarray 檢查,適用於不在 @jit 下的每個基本運算。對於 @jit 下的程式碼,會檢查每個 @jit 函數的輸出,如果存在 NaN,它將以解除最佳化的逐運算元模式重新運行該函數,從而有效地一次移除一個 @jit 層級。

可能會出現棘手的情況,例如僅在 @jit 下發生的 NaN,但在解除最佳化模式下不會產生。在這種情況下,您會看到列印出的警告訊息,但您的程式碼將繼續執行。

如果 NaN 是在梯度評估的反向傳遞中產生的,當堆疊追蹤中較上層的幾個 frame 引發例外時,您將處於 backward_pass 函數中,這本質上是一個簡單的 jaxpr 直譯器,它會反向遍歷基本運算的序列。在下面的範例中,我們使用命令列 env JAX_DEBUG_NANS=True ipython 啟動了 ipython repl,然後執行了以下程式碼

In [1]: import jax.numpy as jnp

In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

... stack trace ...

.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
    103         py_val = device_buffer.to_py()
    104         if np.any(np.isnan(py_val)):
--> 105           raise FloatingPointError("invalid value")
    106         else:
    107           return Array(device_buffer, *result_shape)

FloatingPointError: invalid value

產生的 NaN 被捕獲。透過執行 %debug,我們可以獲得事後除錯器。這也適用於 @jit 下的函數,如下面的範例所示。

In [4]: from jax import jit

In [5]: @jit
   ...: def f(x, y):
   ...:     a = x * y
   ...:     b = (x + y) / (x - y)
   ...:     c = a + 2
   ...:     return a + b * c
   ...:

In [6]: x = jnp.array([2., 0.])

In [7]: y = jnp.array([3., 0.])

In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)

 ... stack trace ...

<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

 ... stack trace ...

當此程式碼在 @jit 函數的輸出中看到 NaN 時,它會調用解除最佳化的程式碼,因此我們仍然可以獲得清晰的堆疊追蹤。而且我們可以執行帶有 %debug 的事後除錯器來檢查所有值,以找出錯誤。

⚠️ 如果您沒有在除錯,則不應開啟 NaN 檢查器,因為它可能會引入大量裝置-主機往返和效能衰退!

⚠️ NaN 檢查器不適用於 pmap。若要除錯 pmap 程式碼中的 NaN,一種嘗試方法是將 pmap 替換為 vmap

🔪 倍精準度 (64 位元)#

目前,JAX 預設強制執行單精準度數字,以減輕 Numpy API 傾向於積極將運算元提升為 double 的情況。對於許多機器學習應用程式來說,這是期望的行為,但它可能會讓您感到驚訝!

x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype
/tmp/ipykernel_1169/1258726447.py:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'>  is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.
  x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')

若要使用倍精準度數字,您需要在啟動時設定 jax_enable_x64 配置變數。

有幾種方法可以做到這一點

  1. 您可以透過設定環境變數 JAX_ENABLE_X64=True 來啟用 64 位元模式。

  2. 您可以手動設定啟動時的 jax_enable_x64 配置 flag

    # again, this only works on startup!
    import jax
    jax.config.update("jax_enable_x64", True)
    
  3. 您可以使用 absl.app.run(main) 來剖析命令列 flags

    import jax
    jax.config.config_with_absl()
    
  4. 如果您希望 JAX 為您執行 absl 剖析,也就是說,您不想執行 absl.app.run(main),您可以改用

    import jax
    if __name__ == '__main__':
      # calls jax.config.config_with_absl() *and* runs absl parsing
      jax.config.parse_flags_with_absl()
    

請注意,#2-#4 適用於 JAX 的任何配置選項。

然後我們可以確認 x64 模式已啟用,例如

import jax
import jax.numpy as jnp
from jax import random

jax.config.update("jax_enable_x64", True)
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')

注意事項#

⚠️ XLA 不支援所有後端的 64 位元卷積!

🔪 與 NumPy 的其他差異#

雖然 jax.numpy 盡一切努力複製 numpy 的 API 行為,但確實存在行為不同的邊角案例。上述章節中詳細討論了許多此類案例;在這裡,我們列出其他幾個已知 API 不同的地方。

  • 對於二元運算,JAX 的型別提升規則與 NumPy 使用的規則略有不同。請參閱 型別提升語意 以取得更多詳細資訊。

  • 當執行不安全的型別轉換 (即目標 dtype 無法表示輸入值的轉換) 時,JAX 的行為可能取決於後端,並且通常可能與 NumPy 的行為不同。Numpy 允許透過 casting 引數控制這些情況下的結果 (請參閱 np.ndarray.astype);JAX 不提供任何此類配置,而是直接繼承 XLA:ConvertElementType 的行為。

    以下是一個不安全轉換的範例,NumPy 和 JAX 之間的結果不同

    >>> np.arange(254.0, 258.0).astype('uint8')
    array([254, 255,   0,   1], dtype=uint8)
    
    >>> jnp.arange(254.0, 258.0).astype('uint8')
    Array([254, 255, 255, 255], dtype=uint8)
    
    

    當從浮點型別轉換為整數型別或反之亦然的極端值時,通常會出現這種不匹配的情況。

🔪 教學中涵蓋的重點#

完。#

如果這裡沒有涵蓋任何讓您痛哭流涕的事情,請告訴我們,我們將擴充這些入門級的建議