jax.ensure_compile_time_eval#

jax.ensure_compile_time_eval()[原始碼]#

上下文管理器,確保在追蹤/編譯時期的求值(或錯誤)。

某些 JAX API,例如 jax.jit()jax.lax.scan() 涉及暫存,即延遲數值表達式(例如 jax.numpy 函數應用)的求值,以便在評估相應的 Python 表達式時,不是急切地執行這些計算,而是單獨執行它們,例如在最佳化編譯之後。但是,這種延遲可能是不希望的。例如,可能需要數值來評估 Python 控制流,因此它們的求值不能延遲。另一個例子是,為了效能原因,確保編譯時期求值(或「常數折疊」)可能是有益的。

此上下文管理器確保 JAX 計算被急切地求值。如果無法進行急切求值,則會引發 ConcretizationTypeError

這是一個虛構的例子

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  with jax.ensure_compile_time_eval():
    y = jnp.sin(3.0)
    z = jnp.sin(y)
    z_positive = z > 0
  if z_positive:  # z_positive is usable in Python control flow
    return jnp.sin(x)
  else:
    return jnp.cos(x)

這是一個來自 jax-ml/jax#3974 的真實世界例子

import jax
import jax.numpy as jnp
from jax import random

@jax.jit
def jax_fn(x):
  with jax.ensure_compile_time_eval():
    y = random.randint(random.key(0), (1000,1000), 0, 100)
  y2 = y @ y
  x2 = jnp.sum(y2) * x
  return x2

通常可以通過簡單地將常數表達式「提升」到相應的暫存 API 之外來實現類似的行為

y = random.randint(random.key(0), (1000,1000), 0, 100)

@jax.jit
def jax_fn(x):
  y2 = y @ y
  x2 = jnp.sum(y2)*x
  return x2

但在某些情況下,使用此上下文管理器可能更方便。