checkify 轉換#

摘要: Checkify 可讓您將可 jit 化的執行階段錯誤檢查 (例如,超出邊界索引) 新增至您的 JAX 程式碼。將 checkify.checkify 轉換與類似 assert 的 checkify.check 函數搭配使用,以將執行階段檢查新增至 JAX 程式碼

from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))

您也可以使用 checkify 自動新增常見檢查

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

err, z = checked_f(jnp.array([5, 1]), 0)
err.throw()  # if no error occurred, throw does nothing!

函數化檢查#

類似 assert 的檢查 API 本身並非函數純粹:它可能會像 assert 一樣,引發 Python 例外作為副作用。因此,它無法與 jitpmappjitscan 一起分階段執行

jax.jit(f)(jnp.ones((5,)), -1)  # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.

但是 checkify 轉換會將這些效果函數化 (或解除)。經過 checkify 轉換的函數會傳回錯誤作為新的輸出,並保持函數純粹。這種函數化表示經過 checkify 轉換的函數可以與階段/轉換任意組合

err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
..  at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
..  at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""

為什麼 JAX 需要 checkify?#

在某些 JAX 轉換下,您可以使用一般的 Python assert 表示執行階段錯誤檢查,例如,僅使用 jax.gradjax.numpy

def f(x):
  assert x > 0., "must be positive!"
  return jnp.log(x)

jax.grad(f)(0.)
# ValueError: "must be positive!"

但是,一般的 assert 在 jitpmappjitscan 內部無法運作。在這些情況下,數值計算會被分階段執行,而不是在 Python 執行期間急切地評估,因此數值無法使用

jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."

JAX 轉換語意依賴於函數純粹性,尤其是在組合多個轉換時,那麼我們如何在不中斷所有這些的情況下提供錯誤機制?除了需要新的 API 之外,情況仍然更加棘手:XLA HLO 不支援 assert 或拋出錯誤,因此即使我們有一個能夠分階段執行 assert 的 JAX API,我們又該如何將這些 assert 降低到 XLA?

您可以想像手動將執行階段檢查新增至您的函數,並將代表錯誤的值導出

def f_checked(x):
  error = x <= 0.
  result = jnp.log(x)
  return error, result

err, y = jax.jit(f_checked)(0.)
if err:
  raise ValueError("must be positive!")
# ValueError: "must be positive!"

錯誤是函數計算的常規值,錯誤在 f_checked 外部引發。f_checked 是函數純粹的,因此我們透過建構知道它可以與 jit、pmap、pjit、scan 和所有 JAX 的轉換一起運作。唯一的問題是這種導出可能很麻煩!

checkify 會為您執行此重寫:這包括透過函數導出錯誤值、將檢查重寫為布林運算,以及將結果與追蹤的錯誤值合併,並將最終錯誤值作為輸出傳回給經過 checkify 的函數

def f(x):
  checkify.check(x > 0., "{} must be positive!", x)  # convenient but effectful API
  return jnp.log(x)

f_checked = checkify(f)

err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1. must be positive! (check failed at <...>:2 (f))

我們將此稱為函數化或解除呼叫檢查所引入的效果。(在上面的「手動」範例中,錯誤值只是一個布林值。checkify 的錯誤值在概念上類似,但也追蹤錯誤訊息並公開 throw 和 get 方法;請參閱 jax.experimental.checkify)。checkify.check 也允許您透過將執行階段值作為格式引數提供給錯誤訊息,將執行階段值新增至您的錯誤訊息。

您現在可以手動使用執行階段檢查來檢測您的程式碼,但 checkify 也可以自動新增常見錯誤的檢查!請考慮以下錯誤情況

jnp.arange(3)[5]                # out of bounds
jnp.sin(jnp.inf)                # NaN generated
jnp.ones((5,)) / jnp.arange(5)  # division by zero

預設情況下,checkify 僅解除 checkify.check,並且不會執行任何操作來捕捉上述錯誤。但是,如果您要求它執行,checkify 也會自動使用檢查來檢測您的程式碼。

def f(x, i):
  y = x[i]        # i could be out of bounds.
  z = jnp.sin(y)  # z could become NaN
  return z

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

用於選擇要啟用哪些自動檢查的 API 是基於 Sets。如需更多詳細資訊,請參閱 jax.experimental.checkify

checkify 在 JAX 轉換下的運作方式。#

如上面的範例所示,經過 checkify 的函數可以順利地進行 jit 處理。以下是一些關於 checkify 與其他 JAX 轉換的更多範例。請注意,經過 checkify 的函數是函數純粹的,並且應該可以與所有 JAX 轉換輕鬆組合!

jit#

您可以安全地將 jax.jit 新增至經過 checkify 的函數,或對經過 jit 處理的函數進行 checkify,兩者都可以運作。

def f(x, i):
  return x[i]

checkify_of_jit = checkify.checkify(jax.jit(f))
jit_of_checkify = jax.jit(checkify.checkify(f))
err, _ =  checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
# out-of-bounds indexing at <..>:2 (f)

vmap/pmap#

您可以 vmappmap 經過 checkify 的函數 (或對映射的函數進行 checkify)。映射經過 checkify 的函數將為您提供映射的錯誤,其中可能包含映射維度中每個元素的不同錯誤。

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
  at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
  at mapped index 2: out-of-bounds indexing at <...>:3 (f)
"""

但是,checkify-of-vmap 將產生單一 (未映射) 錯誤!

@jax.vmap
def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))

pjit#

經過 checkify 的函數的 pjit可以直接運作,您只需要為錯誤值輸出指定額外的 out_axis_resourcesNone

def f(x):
  return x / x

f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
  f,
  in_shardings=PartitionSpec('x', None),
  out_shardings=(None, PartitionSpec('x', None)))

with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
 err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)

grad#

如果執行 checkify-of-grad,您的梯度計算也將被檢測

def f(x):
 return x / (1 + jnp.sqrt(x))

grad_f = jax.grad(f)

err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
>> nan generated by primitive mul at <...>:3 (f)

請注意,f 中沒有乘法,但其梯度計算中存在乘法 (這就是產生 NaN 的地方!)。因此,請使用 checkify-of-grad 將自動檢查新增至前向和後向傳遞操作。

checkify.checks 將僅應用於您函數的原始值。如果您想在梯度值上使用 check,請使用 custom_vjp

@jax.custom_vjp
def assert_gradient_negative(x):
 return x

def fwd(x):
 return assert_gradient_negative(x), None

def bwd(_, grad):
 checkify.check(grad < 0, "gradient needs to be negative!")
 return (grad,)

assert_gradient_negative.defvjp(fwd, bwd)

jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative!

jax.experimental.checkify 的優點和限制#

優點#

  • 您可以在任何地方使用它 (錯誤是「僅僅是值」,並且在轉換下像其他值一樣直觀地運作)

  • 自動檢測:您不需要對程式碼進行本地修改。相反,checkify 可以檢測所有程式碼!

限制#

  • 新增大量執行階段檢查可能會很昂貴 (例如,對每個基本運算新增 NaN 檢查會為您的計算新增大量操作)

  • 需要將錯誤值導出函數並手動拋出錯誤。如果未明確拋出錯誤,您可能會錯過錯誤!

  • 拋出錯誤值會將該錯誤值具體化在主機上,這表示它是一個封鎖操作,會破壞 JAX 的非同步預先執行。