checkify
轉換#
摘要: Checkify 可讓您將可 jit
化的執行階段錯誤檢查 (例如,超出邊界索引) 新增至您的 JAX 程式碼。將 checkify.checkify
轉換與類似 assert 的 checkify.check
函數搭配使用,以將執行階段檢查新增至 JAX 程式碼
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))
您也可以使用 checkify 自動新增常見檢查
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
err, z = checked_f(jnp.array([5, 1]), 0)
err.throw() # if no error occurred, throw does nothing!
函數化檢查#
類似 assert 的檢查 API 本身並非函數純粹:它可能會像 assert 一樣,引發 Python 例外作為副作用。因此,它無法與 jit
、pmap
、pjit
或 scan
一起分階段執行
jax.jit(f)(jnp.ones((5,)), -1) # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.
但是 checkify 轉換會將這些效果函數化 (或解除)。經過 checkify 轉換的函數會傳回錯誤值作為新的輸出,並保持函數純粹。這種函數化表示經過 checkify 轉換的函數可以與階段/轉換任意組合
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
.. at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
.. at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""
為什麼 JAX 需要 checkify?#
在某些 JAX 轉換下,您可以使用一般的 Python assert 表示執行階段錯誤檢查,例如,僅使用 jax.grad
和 jax.numpy
時
def f(x):
assert x > 0., "must be positive!"
return jnp.log(x)
jax.grad(f)(0.)
# ValueError: "must be positive!"
但是,一般的 assert 在 jit
、pmap
、pjit
或 scan
內部無法運作。在這些情況下,數值計算會被分階段執行,而不是在 Python 執行期間急切地評估,因此數值無法使用
jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."
JAX 轉換語意依賴於函數純粹性,尤其是在組合多個轉換時,那麼我們如何在不中斷所有這些的情況下提供錯誤機制?除了需要新的 API 之外,情況仍然更加棘手:XLA HLO 不支援 assert 或拋出錯誤,因此即使我們有一個能夠分階段執行 assert 的 JAX API,我們又該如何將這些 assert 降低到 XLA?
您可以想像手動將執行階段檢查新增至您的函數,並將代表錯誤的值導出
def f_checked(x):
error = x <= 0.
result = jnp.log(x)
return error, result
err, y = jax.jit(f_checked)(0.)
if err:
raise ValueError("must be positive!")
# ValueError: "must be positive!"
錯誤是函數計算的常規值,錯誤在 f_checked
外部引發。f_checked
是函數純粹的,因此我們透過建構知道它可以與 jit
、pmap、pjit、scan 和所有 JAX 的轉換一起運作。唯一的問題是這種導出可能很麻煩!
checkify
會為您執行此重寫:這包括透過函數導出錯誤值、將檢查重寫為布林運算,以及將結果與追蹤的錯誤值合併,並將最終錯誤值作為輸出傳回給經過 checkify 的函數
def f(x):
checkify.check(x > 0., "{} must be positive!", x) # convenient but effectful API
return jnp.log(x)
f_checked = checkify(f)
err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: -1. must be positive! (check failed at <...>:2 (f))
我們將此稱為函數化或解除呼叫檢查所引入的效果。(在上面的「手動」範例中,錯誤值只是一個布林值。checkify 的錯誤值在概念上類似,但也追蹤錯誤訊息並公開 throw 和 get 方法;請參閱 jax.experimental.checkify
)。checkify.check
也允許您透過將執行階段值作為格式引數提供給錯誤訊息,將執行階段值新增至您的錯誤訊息。
您現在可以手動使用執行階段檢查來檢測您的程式碼,但 checkify
也可以自動新增常見錯誤的檢查!請考慮以下錯誤情況
jnp.arange(3)[5] # out of bounds
jnp.sin(jnp.inf) # NaN generated
jnp.ones((5,)) / jnp.arange(5) # division by zero
預設情況下,checkify
僅解除 checkify.check
,並且不會執行任何操作來捕捉上述錯誤。但是,如果您要求它執行,checkify
也會自動使用檢查來檢測您的程式碼。
def f(x, i):
y = x[i] # i could be out of bounds.
z = jnp.sin(y) # z could become NaN
return z
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
用於選擇要啟用哪些自動檢查的 API 是基於 Sets。如需更多詳細資訊,請參閱 jax.experimental.checkify
。
checkify
在 JAX 轉換下的運作方式。#
如上面的範例所示,經過 checkify 的函數可以順利地進行 jit 處理。以下是一些關於 checkify
與其他 JAX 轉換的更多範例。請注意,經過 checkify 的函數是函數純粹的,並且應該可以與所有 JAX 轉換輕鬆組合!
jit
#
您可以安全地將 jax.jit
新增至經過 checkify 的函數,或對經過 jit 處理的函數進行 checkify
,兩者都可以運作。
def f(x, i):
return x[i]
checkify_of_jit = checkify.checkify(jax.jit(f))
jit_of_checkify = jax.jit(checkify.checkify(f))
err, _ = checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
# out-of-bounds indexing at <..>:2 (f)
vmap
/pmap
#
您可以 vmap
和 pmap
經過 checkify 的函數 (或對映射的函數進行 checkify
)。映射經過 checkify 的函數將為您提供映射的錯誤,其中可能包含映射維度中每個元素的不同錯誤。
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
errs.throw()
"""
ValueError:
at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
at mapped index 2: out-of-bounds indexing at <...>:3 (f)
"""
但是,checkify-of-vmap 將產生單一 (未映射) 錯誤!
@jax.vmap
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
return x[i]
checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))
pjit
#
經過 checkify 的函數的 pjit
可以直接運作,您只需要為錯誤值輸出指定額外的 out_axis_resources
為 None
。
def f(x):
return x / x
f = checkify.checkify(f, errors=checkify.float_checks)
f = pjit(
f,
in_shardings=PartitionSpec('x', None),
out_shardings=(None, PartitionSpec('x', None)))
with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
err, data = f(input_data)
err.throw()
# ValueError: divided by zero at <...>:4 (f)
grad
#
如果執行 checkify-of-grad,您的梯度計算也將被檢測
def f(x):
return x / (1 + jnp.sqrt(x))
grad_f = jax.grad(f)
err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
print(err.get())
>> nan generated by primitive mul at <...>:3 (f)
請注意,f
中沒有乘法,但其梯度計算中存在乘法 (這就是產生 NaN 的地方!)。因此,請使用 checkify-of-grad 將自動檢查新增至前向和後向傳遞操作。
checkify.check
s 將僅應用於您函數的原始值。如果您想在梯度值上使用 check
,請使用 custom_vjp
@jax.custom_vjp
def assert_gradient_negative(x):
return x
def fwd(x):
return assert_gradient_negative(x), None
def bwd(_, grad):
checkify.check(grad < 0, "gradient needs to be negative!")
return (grad,)
assert_gradient_negative.defvjp(fwd, bwd)
jax.grad(assert_gradient_negative)(-1.)
# ValueError: gradient needs to be negative!
jax.experimental.checkify
的優點和限制#
優點#
您可以在任何地方使用它 (錯誤是「僅僅是值」,並且在轉換下像其他值一樣直觀地運作)
自動檢測:您不需要對程式碼進行本地修改。相反,
checkify
可以檢測所有程式碼!
限制#
新增大量執行階段檢查可能會很昂貴 (例如,對每個基本運算新增 NaN 檢查會為您的計算新增大量操作)
需要將錯誤值導出函數並手動拋出錯誤。如果未明確拋出錯誤,您可能會錯過錯誤!
拋出錯誤值會將該錯誤值具體化在主機上,這表示它是一個封鎖操作,會破壞 JAX 的非同步預先執行。