jax.ensure_compile_time_eval#
- jax.ensure_compile_time_eval()[原始碼]#
上下文管理器,確保在追蹤/編譯時期的求值(或錯誤)。
某些 JAX API,例如
jax.jit()
和jax.lax.scan()
涉及暫存,即延遲數值表達式(例如jax.numpy
函數應用)的求值,以便在評估相應的 Python 表達式時,不是急切地執行這些計算,而是單獨執行它們,例如在最佳化編譯之後。但是,這種延遲可能是不希望的。例如,可能需要數值來評估 Python 控制流,因此它們的求值不能延遲。另一個例子是,為了效能原因,確保編譯時期求值(或「常數折疊」)可能是有益的。此上下文管理器確保 JAX 計算被急切地求值。如果無法進行急切求值,則會引發
ConcretizationTypeError
。這是一個虛構的例子
import jax import jax.numpy as jnp @jax.jit def f(x): with jax.ensure_compile_time_eval(): y = jnp.sin(3.0) z = jnp.sin(y) z_positive = z > 0 if z_positive: # z_positive is usable in Python control flow return jnp.sin(x) else: return jnp.cos(x)
這是一個來自 jax-ml/jax#3974 的真實世界例子
import jax import jax.numpy as jnp from jax import random @jax.jit def jax_fn(x): with jax.ensure_compile_time_eval(): y = random.randint(random.key(0), (1000,1000), 0, 100) y2 = y @ y x2 = jnp.sum(y2) * x return x2
通常可以通過簡單地將常數表達式「提升」到相應的暫存 API 之外來實現類似的行為
y = random.randint(random.key(0), (1000,1000), 0, 100) @jax.jit def jax_fn(x): y2 = y @ y x2 = jnp.sum(y2)*x return x2
但在某些情況下,使用此上下文管理器可能更方便。