JAX internals: primitives#

Introduction to JAX primitives#

A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support in order to allow JAX to perform all of its transformations (it is not a how-to guide).

For example, the multiply-add operation can be implemented in terms of the low-level jax.lax.* primitives (which are like wrappers for XLA operators) or with jax.core.Primitive("multiply_add"), as shown below.

JAX is able to take sequences of such primitive operations and transform them via its composable transformations of Python functions, such as jax.jit(), jax.grad() and jax.vmap(). JAX implements these transformations in a *JAX-traceable* way. This means that when a Python function is executed, the only operations it applies to the data are either:

  • Inspections of data attributes: data information, such as shape or type; or

  • JAX primitives: the special JAX operations covered in this tutorial.

JAX primitives know how to operate on both concrete data values and on abstract JAX values. A JAX-traceable function can be invoked by JAX with abstract arguments. For example, the JAX abstract value ShapedArray(float32[2,2]) captures the type and shape of a value, but not its concrete data.
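
To make this concrete, here is a minimal sketch (not one of the original cells) that uses jax.make_jaxpr to show what JAX records when it traces a function with abstract values:

import jax
import jax.numpy as jnp

def multiply_add(x, y, z):
  return jnp.add(jnp.multiply(x, y), z)

# Tracing records only the primitives applied, not concrete data.
print(jax.make_jaxpr(multiply_add)(2.0, 3.0, 10.0))
# Prints something like:
# { lambda ; a:f32[] b:f32[] c:f32[]. let
#     d:f32[] = mul a b
#     e:f32[] = add d c
#   in (e,) }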

JAX-transformed functions must themselves be JAX-traceable functions, to make sure that these transformations are composable, for example jax.jit(jax.jacfwd(jax.grad(f))).

JAX provides pre-defined primitives corresponding to most XLA operations, including add, matrix multiply, sine, cosine, and indexing.

In addition, JAX provides an implementation of NumPy functions in terms of JAX primitives. This means that Python programs using JAX's implementation of NumPy are JAX-traceable and therefore transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.

Furthermore, the set of JAX primitives is extensible, so instead of reimplementing a function in terms of pre-defined JAX primitives, you can define a new primitive that encapsulates the behavior of the function.

Consider the following example: you want to add to JAX support for a multiply-add function with three arguments, defined mathematically as multiply_add(x, y, z) = x * y + z. This function operates pointwise on 3 identically-shaped tensors of floating point values. You can do this as follows:

Using existing JAX primitives#

The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, such as those defined in the jax.lax module:

from jax import lax
from jax._src import api

def multiply_add_lax(x, y, z):
  """Implementation of multiply-add using the `jax.lax` primitives."""
  return lax.add(lax.mul(x, y), z)


def square_add_lax(a, b):
  """A square-add function using the newly defined multiply-add."""
  return multiply_add_lax(a, a, b)

print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
square_add_lax =  14.0
grad(square_add_lax) =  4.0
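
To see the sequence of primitives that JAX records for this function, you can inspect its jaxpr. This is an illustrative sketch, not one of the original cells:

import jax

print(jax.make_jaxpr(square_add_lax)(2.0, 10.0))
# Prints something like:
# { lambda ; a:f32[] b:f32[]. let
#     c:f32[] = mul a a
#     d:f32[] = add c b
#   in (d,) }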

To understand how JAX uses primitives internally, add some helpers for tracing function calls:

#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0
def _trace(msg=None):
    """Print a message at current indentation."""
    if msg is not None:
        print("  " * _indentation + msg)

def _trace_indent(msg=None):
    """Print a message and then indent the rest."""
    global _indentation
    _trace(msg)
    _indentation = 1 + _indentation

def _trace_unindent(msg=None):
    """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 1
    _trace(msg)

def trace(name):
  """A decorator for functions to trace arguments and results."""

  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
        """Print certain values more succinctly"""
        vtype = str(type(v))
        if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
            return "<JaxComputationBuilder>"
        elif "jaxlib.xla_extension.XlaOp" in vtype:
            return "<XlaOp at 0x{:x}>".format(id(v))
        elif ("partial_eval.JaxprTracer" in vtype or
              "batching.BatchTracer" in vtype or
              "ad.JVPTracer" in vtype):
            return "Traced<{}>".format(v.aval)
        elif isinstance(v, tuple):
            return "({})".format(pp_values(v))
        else:
            return str(v)
    def pp_values(args):
        return ", ".join([pp(arg) for arg in args])
    
    @functools.wraps(func)
    def func_wrapper(*args):
      _trace_indent("call {}({})".format(name, pp_values(args)))
      res = func(*args)
      _trace_unindent("|<- {} = {}".format(name, pp(res)))
      return res

    return func_wrapper

  return trace_func

class expectNotImplementedError(object):
  """Context manager to check for NotImplementedError."""
  def __enter__(self): pass
  def __exit__(self, type, value, tb):
    global _indentation
    _indentation = 0
    if type is NotImplementedError:
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True
    elif type is None:  # No exception
      assert False, "Expected NotImplementedError"
    else:
      return False

Instead of using the jax.lax primitives directly, you can use other functions that are already written in terms of those primitives, such as the ones in jax.numpy:

import jax.numpy as jnp
import numpy as np

@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)

@trace("square_add_numpy")
def square_add_numpy(a, b):
    return multiply_add_numpy(a, a, b)

print("\nNormal evaluation:")  
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
Normal evaluation:
call square_add_numpy(2.0, 10.0)
  call multiply_add_numpy(2.0, 2.0, 10.0)
  |<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy =  14.0

Gradient evaluation:
call square_add_numpy(Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
  call multiply_add_numpy(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
  |<- multiply_add_numpy = Traced<ShapedArray(float32[], weak_type=True)>
|<- square_add_numpy = Traced<ShapedArray(float32[], weak_type=True)>
grad(square_add_numpy) =  4.0

Notice that, in the process of computing jax.grad(), JAX invokes square_add_numpy and multiply_add_numpy with special arguments ConcreteArray(...) (described further below in this colab). It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on the special abstract arguments that JAX may use to abstract the function execution.

The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
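
Conversely, a function that applies non-JAX operations to traced data is not JAX-traceable. As a hypothetical counterexample (not from the original cells), using plain NumPy instead of jax.numpy fails under jax.grad, because NumPy cannot operate on the abstract tracer arguments:

import jax
import numpy as np

def square_add_plain_numpy(a, b):
  # Plain NumPy ufuncs try to coerce their inputs to concrete arrays,
  # which fails for the tracers that jax.grad passes in.
  return np.add(np.multiply(a, a), b)

try:
  jax.grad(square_add_plain_numpy)(2.0, 10.0)
except Exception as e:
  print(type(e).__name__)  # e.g. TracerArrayConversionError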

Defining new JAX primitives#

The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, in order to demonstrate how JAX primitives work, pretend that you want to add a new primitive to JAX for the multiply-add functionality.

from jax import core

multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.
  
  Note that the traced arguments must be passed as positional arguments
  to `bind`. 
  """
  return multiply_add_p.bind(x, y, z)

@trace("square_add_prim")
def square_add_prim(a, b):
  """A square-add function implemented using the new JAX-primitive."""
  return multiply_add_prim(a, a, b)
/tmp/ipykernel_1057/1751132419.py:3: DeprecationWarning: jax.core.Primitive is deprecated. Use jax.extend.core.Primitive instead, and see https://jax.dev.org.tw/en/latest/jax.extend.html for details.
  multiply_add_p = core.Primitive("multiply_add")  # Create the primitive
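
As the DeprecationWarning above suggests, newer JAX releases expose Primitive under jax.extend. A minimal sketch of the equivalent, non-deprecated construction (assuming a recent JAX version, per the warning):

# Same primitive, built from the non-deprecated location:
from jax.extend import core as jex_core

multiply_add_p = jex_core.Primitive("multiply_add")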

If you try to call the newly defined functions, you'll get an error, because you have not yet told JAX anything about the semantics of the new primitive.

with expectNotImplementedError():
  square_add_prim(2., 10.)
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1057/2844449444.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/tmp/ipykernel_1057/1393342955.py", line 48, in func_wrapper
    res = func(*args)
  File "/tmp/ipykernel_1057/1751132419.py", line 17, in square_add_prim
    return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented

Primal evaluation rules#

@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.

  This function does not need to be JAX traceable.

  Args:
    x, y, z: The concrete arguments of the primitive. Will only be called with 
      concrete values.

  Returns:
    the concrete result of the primitive.
  """
  # Note: you can use the ordinary (non-JAX) NumPy, which is not JAX-traceable.
  return np.add(np.multiply(x, y), z)

# Now, register the primal implementation with JAX:
multiply_add_p.def_impl(multiply_add_impl)
<function __main__.multiply_add_impl(x, y, z)>

assert square_add_prim(2., 10.) == 14.
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)
    call multiply_add_impl(2.0, 2.0, 10.0)
    |<- multiply_add_impl = 14.0
  |<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0

What happens when you use jit#

Now, if you try to use jit, you'll get a NotImplementedError:

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1057/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 340, in cache_miss
    pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented

Abstract evaluation rules#

In order to JIT the function, and for other transformations as well, JAX first evaluates it abstractly, using only the shape and type of the arguments. This abstract evaluation serves multiple purposes:

  • It obtains the sequence of JAX primitives used in the computation. This sequence will be compiled.

  • It computes the shape and type of all vectors and operations used in the computation.

For example, the abstraction of a vector with 3 elements may be ShapedArray(float32[3]), or ConcreteArray([1., 2., 3.]). In the latter case, JAX uses the actual concrete value wrapped as an abstract value.
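
From user code, jax.eval_shape is a quick way to watch abstract evaluation at work. This sketch (not one of the original cells) computes only the output shape and dtype, never touching concrete data:

import jax
import jax.numpy as jnp

abstract_arg = jax.ShapeDtypeStruct((3,), jnp.float32)
out = jax.eval_shape(lambda x, y, z: jnp.add(jnp.multiply(x, y), z),
                     abstract_arg, abstract_arg, abstract_arg)
print(out)  # ShapeDtypeStruct(shape=(3,), dtype=float32)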

from jax import core

@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

  This function does not need to be JAX traceable. It will be invoked with
  abstractions of the actual arguments

  Args:
    xs, ys, zs: Abstractions of the arguments.

  Result:
    a ShapedArray for the result of the primitive.
  """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)

# Now, register the abstract evaluation with JAX:
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
<function __main__.multiply_add_abstract_eval(xs, ys, zs)>

If you re-attempt to apply jit, you can inspect how the abstract evaluation proceeds, but you'll get another error, this time about the missing actual XLA compilation rule:

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>

Found expected exception:
Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.15/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.15/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_1057/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 340, in cache_miss
    pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

XLA compilation rules#

JAX compilation works by compiling each primitive into a graph of XLA operations.

This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a CustomCall operation that can be used to encapsulate arbitrary functionality defined using C++.

from jax._src.lib.mlir.dialects import hlo

@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.

  Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
  the results of the function.

  Does not need to be a JAX-traceable function.
  """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# Now, register the lowering rule with JAX.
# For GPU, refer to the https://jax.dev.org.tw/en/latest/Custom_Operation_for_GPUs.html
from jax.interpreters import mlir

mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
<function __main__.multiply_add_lowering(ctx, xc, yc, zc)>

You will now succeed in applying jax.jit. Notice that JAX first evaluates the function abstractly, which triggers the multiply_add_abstract_eval function, and then compiles the set of primitives it has encountered, including multiply_add. At this point JAX invokes multiply_add_lowering:

assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7fb25450d1c0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fb255765580>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fb2544ff910>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fb2557a1c50>, platforms=('cpu',), backend=<jaxlib.xla_extension.Client object at 0x7fb258145080>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fb2557a1ba0>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56102ed26c00>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1057/1570919344.py":1:0) at callsite("<module>"("/tmp/ipykernel_1057/1570919344.py":1:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at "_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":128:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fb2580fbaa0, file "/tmp/ipykernel_1057/1751132419.py", line 5>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0)), (<code object func_wrapper at 0x7fb2580face0, file "/tmp/ipykernel_1057/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0)), (<code object square_add_prim at 0x7fb2580fb730, file "/tmp/ipykernel_1057/1751132419.py", line 14>, 8): loc("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0)), (<code object <lambda> at 0x7fb2557ac450, file "/tmp/ipykernel_1057/1570919344.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_1057/1570919344.py":1:0)), (<code object <module> at 0x7fb2557ac710, file "/tmp/ipykernel_1057/1570919344.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1057/1570919344.py":1:0)), (<code object run_code at 0x7fb29538b050, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object run_ast_nodes at 0x7fb29538aef0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3418>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0)), (<code object run_cell_async at 0x7fb29538ab80, 
file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3183>, 828): loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0)), (<code object _pseudo_sync_runner at 0x7fb295255790, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 119>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":128:0))}, canonical_name_cache={'/tmp/ipykernel_1057/1751132419.py': '/tmp/ipykernel_1057/1751132419.py', '/tmp/ipykernel_1057/1393342955.py': '/tmp/ipykernel_1057/1393342955.py', '/tmp/ipykernel_1057/1570919344.py': '/tmp/ipykernel_1057/1570919344.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1057/1751132419.py': True, '/tmp/ipykernel_1057/1393342955.py': True, '/tmp/ipykernel_1057/1570919344.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fb2557a18a0>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=None, 
xla_metadata={}), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fb25574e870>]

Below is another use of jit, where you compile only with respect to the first argument. Note that the second argument of square_add_prim is concrete, which leads to the third argument of multiply_add_abstract_eval being ConcreteArray. Note also that multiply_add_abstract_eval may be used with both ShapedArray and ConcreteArray:

assert api.jit(lambda x, y: square_add_prim(x, y), 
               static_argnums=1)(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, 10.0)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, 10.0)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7fb25450d8c0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fb2545402c0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fb2545403f0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fb2557a32d0>, platforms=('cpu',), backend=<jaxlib.xla_extension.Client object at 0x7fb258145080>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fb2557a3370>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56102ed26c00>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1057/4165789807.py":1:0) at callsite("<module>"("/tmp/ipykernel_1057/4165789807.py":1:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at "_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":128:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fb2580fbaa0, file "/tmp/ipykernel_1057/1751132419.py", line 5>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0)), (<code object func_wrapper at 0x7fb2580face0, file "/tmp/ipykernel_1057/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0)), (<code object square_add_prim at 0x7fb2580fb730, file "/tmp/ipykernel_1057/1751132419.py", line 14>, 8): loc("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0)), (<code object <lambda> at 0x7fb2557ae550, file "/tmp/ipykernel_1057/4165789807.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_1057/4165789807.py":1:0)), (<code object <module> at 0x7fb2557ae8c0, file "/tmp/ipykernel_1057/4165789807.py", line 1>, 20): loc("<module>"("/tmp/ipykernel_1057/4165789807.py":1:0)), (<code object run_code at 0x7fb29538b050, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object run_ast_nodes at 0x7fb29538aef0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3418>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0)), (<code object run_cell_async at 0x7fb29538ab80, 
file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3183>, 828): loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0)), (<code object _pseudo_sync_runner at 0x7fb295255790, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 119>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":128:0))}, canonical_name_cache={'/tmp/ipykernel_1057/1751132419.py': '/tmp/ipykernel_1057/1751132419.py', '/tmp/ipykernel_1057/1393342955.py': '/tmp/ipykernel_1057/1393342955.py', '/tmp/ipykernel_1057/4165789807.py': '/tmp/ipykernel_1057/4165789807.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1057/1751132419.py': True, '/tmp/ipykernel_1057/1393342955.py': True, '/tmp/ipykernel_1057/4165789807.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fb2557a3be0>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=None, 
xla_metadata={}), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fb254548b70>]

Forward differentiation#

JAX implements forward differentiation in the form of a Jacobian-Vector Product (JVP) (you can learn more about it in Advanced automatic differentiation).
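
As a quick reminder of the jvp contract, here is an illustrative sketch using a built-in primitive (not one of the original cells): primal values and tangent values go in, and the primal output comes back together with the output tangent:

import jax
import jax.numpy as jnp

# d/dx sin(x) = cos(x), so at x = 0.0 with input tangent 1.0 the
# output tangent is cos(0.0) * 1.0 = 1.0.
primal_out, tangent_out = jax.jvp(jnp.sin, (0.0,), (1.0,))
print(primal_out, tangent_out)  # 0.0 1.0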

If you attempt to compute the jvp function, you'll get an error because you have not yet told JAX how to differentiate the multiply_add primitive:

# The second argument is set to `(2., 10.)` values where you
# evaluate the Jacobian, and the third argument `(1., 1.)`
# contains the values of the tangents for the arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.))
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1057/459539105.py", line 5, in <module>
    api.jvp(square_add_prim, (2., 10.), (1., 1.))
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1700, in jvp
    return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1729, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
NotImplementedError: Differentiation rule for 'multiply_add' not implemented

from jax.interpreters import ad

@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).

  Given values of the arguments and perturbation of the arguments (tangents), 
  compute the output of the primitive and the perturbation of the output.

  This method must be JAX-traceable. JAX may invoke it with abstract values 
  for the arguments and tangents.

  Args:
    arg_values: A tuple of arguments
    arg_tangents: A tuple with the tangents of the arguments. The tuple has 
      the same length as the arg_values. Some of the tangents may also be the 
      special value `ad.Zero` to specify a zero tangent

  Returns:
     A pair of the primal output and the tangent.
  """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  _trace("Primal evaluation:")
  # Now, you have a JAX-traceable computation of the output. 
  # Normally, you can use the multiply add (`ma`) primitive itself to compute the primal output. 
  primal_out = multiply_add_prim(x, y, z)

  _trace("Tangent evaluation:")
  # You must use a JAX-traceable way to compute the tangent. It turns out that 
  # the output tangent can be computed as (xt * y + x * yt + zt),
  # which you can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.

  # You do need to deal specially with `Zero`. Here, you just turn it into a 
  # proper tensor of 0s (of the same shape as 'x'). 
  # An alternative would be to check for `Zero` and perform algebraic 
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan  

  output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX:
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, 1.0, 1.0)
        call multiply_add_impl(2.0, 1.0, 1.0)
        |<- multiply_add_impl = 3.0
      |<- multiply_add_prim = 3.0
      call multiply_add_prim(1.0, 2.0, 3.0)
        call multiply_add_impl(1.0, 2.0, 3.0)
        |<- multiply_add_impl = 5.0
      |<- multiply_add_prim = 5.0
    |<- multiply_add_value_and_jvp = (14.0, 5.0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>

JIT of forward differentiation#

You can apply jit to the forward differentiation function:

assert api.jit(lambda arg_values, arg_tangents: 
                   api.jvp(square_add_prim, arg_values, arg_tangents))(
         (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>), (Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7fb25457adc0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fb254582930>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fb2545828d0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fb2557a28e0>, platforms=('cpu',), backend=<jaxlib.xla_extension.Client object at 0x7fb258145080>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fb2557a2920>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56102ed269d0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1057/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1057/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fb2580fbaa0, file "/tmp/ipykernel_1057/1751132419.py", line 5>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0)), (<code object func_wrapper at 0x7fb2580face0, file "/tmp/ipykernel_1057/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fb2557addc0, file "/tmp/ipykernel_1057/347789876.py", line 3>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":27:0)), (<code object square_add_prim at 0x7fb2580fb730, file "/tmp/ipykernel_1057/1751132419.py", line 14>, 8): loc("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0)), (<code object <lambda> at 0x7fb2557ae760, file "/tmp/ipykernel_1057/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_1057/2145028508.py":2:0)), (<code object <module> at 0x7fb2557ad840, file "/tmp/ipykernel_1057/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1057/2145028508.py":1:0))}, canonical_name_cache={'/tmp/ipykernel_1057/1751132419.py': '/tmp/ipykernel_1057/1751132419.py', '/tmp/ipykernel_1057/1393342955.py': '/tmp/ipykernel_1057/1393342955.py', '/tmp/ipykernel_1057/347789876.py': '/tmp/ipykernel_1057/347789876.py', '/tmp/ipykernel_1057/2145028508.py': '/tmp/ipykernel_1057/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1057/1751132419.py': True, '/tmp/ipykernel_1057/1393342955.py': True, '/tmp/ipykernel_1057/347789876.py': True, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_1057/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fb2545902e0>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=None, xla_metadata={}), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fb255761bb0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7fb25457adc0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fb254582930>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fb2545828d0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fb2557a28e0>, platforms=('cpu',), backend=<jaxlib.xla_extension.Client object at 0x7fb258145080>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fb2557a2920>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56102ed269d0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1057/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1057/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56102edd0550>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1057/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1057/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fb2580fbaa0, file "/tmp/ipykernel_1057/1751132419.py", line 5>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0)), (<code object func_wrapper at 0x7fb2580face0, file "/tmp/ipykernel_1057/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fb2557addc0, file "/tmp/ipykernel_1057/347789876.py", line 3>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":27:0)), (<code object square_add_prim at 0x7fb2580fb730, file "/tmp/ipykernel_1057/1751132419.py", line 14>, 8): loc("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0)), (<code object <lambda> at 0x7fb2557ae760, file "/tmp/ipykernel_1057/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_1057/2145028508.py":2:0)), (<code object <module> at 0x7fb2557ad840, file "/tmp/ipykernel_1057/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1057/2145028508.py":1:0)), (<code object multiply_add_value_and_jvp at 0x7fb2557addc0, file "/tmp/ipykernel_1057/347789876.py", line 3>, 86): 
loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_1057/1751132419.py': '/tmp/ipykernel_1057/1751132419.py', '/tmp/ipykernel_1057/1393342955.py': '/tmp/ipykernel_1057/1393342955.py', '/tmp/ipykernel_1057/347789876.py': '/tmp/ipykernel_1057/347789876.py', '/tmp/ipykernel_1057/2145028508.py': '/tmp/ipykernel_1057/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1057/1751132419.py': True, '/tmp/ipykernel_1057/1393342955.py': True, '/tmp/ipykernel_1057/347789876.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_1057/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fb254590430>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=None, xla_metadata={}), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 3))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fb25804f470>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7fb25457adc0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fb254582930>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fb2545828d0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fb2557a28e0>, platforms=('cpu',), backend=<jaxlib.xla_extension.Client object at 0x7fb258145080>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fb2557a2920>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56102ed269d0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1057/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1057/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56102edd0550>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1057/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1057/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56102edb2c60>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1057/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1057/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fb2580fbaa0, file "/tmp/ipykernel_1057/1751132419.py", line 5>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0)), (<code object func_wrapper at 0x7fb2580face0, file "/tmp/ipykernel_1057/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fb2557addc0, file 
"/tmp/ipykernel_1057/347789876.py", line 3>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":27:0)), (<code object square_add_prim at 0x7fb2580fb730, file "/tmp/ipykernel_1057/1751132419.py", line 14>, 8): loc("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0)), (<code object <lambda> at 0x7fb2557ae760, file "/tmp/ipykernel_1057/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_1057/2145028508.py":2:0)), (<code object <module> at 0x7fb2557ad840, file "/tmp/ipykernel_1057/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1057/2145028508.py":1:0)), (<code object multiply_add_value_and_jvp at 0x7fb2557addc0, file "/tmp/ipykernel_1057/347789876.py", line 3>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":41:0)), (<code object multiply_add_value_and_jvp at 0x7fb2557addc0, file "/tmp/ipykernel_1057/347789876.py", line 3>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_1057/1751132419.py': '/tmp/ipykernel_1057/1751132419.py', '/tmp/ipykernel_1057/1393342955.py': '/tmp/ipykernel_1057/1393342955.py', '/tmp/ipykernel_1057/347789876.py': '/tmp/ipykernel_1057/347789876.py', '/tmp/ipykernel_1057/2145028508.py': '/tmp/ipykernel_1057/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1057/1751132419.py': True, '/tmp/ipykernel_1057/1393342955.py': True, '/tmp/ipykernel_1057/347789876.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_1057/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fb254590490>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=None, xla_metadata={}), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 2), 
Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%3 = "stablehlo.add"(%2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fb2545878b0>]

Notice that, first, you evaluate multiply_add_value_and_jvp abstractly, which in turn evaluates abstractly both the primal and the tangent computation (a total of 3 invocations of the ma primitive). Then, you compile the 3 occurrences of the primitive.

Reverse differentiation#

If you now attempt to use reverse differentiation, you'll notice that JAX first uses multiply_add_value_and_jvp to compute the forward differentiation for abstract values, but then runs into a NotImplementedError.

When computing reverse differentiation, JAX first performs an abstract evaluation of the forward differentiation code multiply_add_value_and_jvp to obtain a trace of primitives that compute the output tangent.

  • Observe that JAX performs this abstract evaluation with concrete values for the differentiation point, and abstract values for the tangents.

  • Notice that JAX uses the special abstract tangent value Zero for the tangent corresponding to the third argument of ma. This reflects the fact that you do not differentiate w.r.t. the second argument of square_add_prim, which flows into the third argument of multiply_add_prim.

  • Notice also that during the abstract evaluation of the tangent, the value 0.0 is passed as the tangent for the third argument. This is due to the use of the make_zero function in the definition of multiply_add_value_and_jvp.

# This is reverse differentiation w.r.t. the first argument of `square_add_prim`
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
Found expected exception:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 391, in get_primitive_transpose
    return primitive_transposes[p]
KeyError: multiply_add

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.15/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.15/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_1057/2155094905.py", line 3, in <module>
    api.grad(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 396, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The above error occurs because there is a missing piece for JAX to be able to use the forward differentiation code to compute reverse differentiation.

Transposition#

As explained above, when computing reverse differentiation, JAX obtains a trace of primitives that compute the tangent using forward differentiation. Then, JAX interprets this trace abstractly backwards and for each primitive it applies a transposition rule.

To understand what is going on, consider a simpler example of the function f(x, y) = x * y + y. Assume you need to differentiate at the point (2., 4.). JAX will produce the following JVP tangent calculation of ft from the tangents of the inputs xt and yt:

   a = xt * 4.
   b = 2. * yt
   c = a + b
   ft = c + yt

By construction, the tangent calculation is always linear in the input tangents. The only non-linear operator that may arise in the tangent calculation is multiplication, but then one of the operands is constant.
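
You can check this tangent computation directly with jax.jvp on a plain Python version of f. This is just a quick sanity check, separate from the primitive machinery built in this tutorial:

import jax

def f(x, y):
  return x * y + y

# The JVP at (2., 4.) is linear in the tangents: ft = 4. * xt + 3. * yt.
print(jax.jvp(f, (2., 4.), (1., 0.))[1])  # 4.0 (coefficient of xt)
print(jax.jvp(f, (2., 4.), (0., 1.))[1])  # 3.0 (coefficient of yt)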

JAX will produce the reverse differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it accumulates the cotangents of the variables used by the operation, using the cotangent of the result of the operation:

  # Initialize cotangents of inputs and intermediate variables:
  xct = yct = act = bct = cct = 0.
  # Initialize cotangent of the output:
  fct = 1.
  # Process `ft = c + yt`:
  cct += fct
  yct += fct
  # Process `c = a + b`:
  act += cct
  bct += cct
  # Process `b = 2. * yt`:
  yct += 2. * bct
  # Process `a = xt * 4.`:
  xct += act * 4.

You can verify that this computation produces xct = 4. and yct = 3., which are the partial derivatives of the function f.
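
The same partial derivatives fall out of jax.grad, which performs exactly this backward pass (reusing the plain function f and the jax import from the sketch above):

# grad runs the JVP forward and then transposes it, accumulating cotangents
# exactly as in the manual computation above.
print(jax.grad(f, argnums=(0, 1))(2., 4.))  # (4.0, 3.0)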

JAX knows how to transpose each of the primitives that may appear in a JVP calculation. Conceptually, if the primitive p(x, y, z) is linear in the arguments y and z for a constant value of x, e.g., p(x, y, z) = y*cy + z*cz, then the transposition of the primitive is:

p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)

Notice that p_transpose takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets the undefined _ value, while for the other arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value None returned for the constant arguments.

In particular:

 add_transpose(out_ct, _, _) = (out_ct, out_ct)
 mult_transpose(out_ct, x, _) = (None, x * out_ct)
 mult_transpose(out_ct, _, y) = (out_ct * y, None)
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
  """Evaluates the transpose of a linear primitive.

  This method is only used when computing the backward gradient following 
  `value_and_jvp`, and is only needed for primitives that are used in the JVP 
  calculation for some other primitive. You need a transposition for `multiply_add_prim`, 
  because you have used `multiply_add_prim` in the computation of the `output_tangent` in 
  `multiply_add_value_and_jvp`.

  In this case, multiply_add is not a linear primitive. However, it is used linearly 
  w.r.t. tangents in `multiply_add_value_and_jvp`:
       `output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))`.

  One of the first two multiplicative arguments is always a constant.

  Args:
      ct: The cotangent of the output of the primitive.
      x, y, z: The values of the arguments. The arguments that are used linearly
        get an ad.UndefinedPrimal value. The other arguments get a constant
        value.

  Returns:
      A tuple with the cotangent of the inputs, with the value None
      corresponding to the constant arguments.
  """
  if not ad.is_undefined_primal(x):
    # This use of multiply_add is with a constant "x".
    assert ad.is_undefined_primal(y)
    ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
    res = None, ct_y, ct
  else:
    # This use of multiply_add is with a constant "y".
    assert ad.is_undefined_primal(x)
    ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
    res = ct_x, None, ct
  return res

ad.primitive_transposes[multiply_add_p] = multiply_add_transpose

Now you can complete the run of the grad:

assert api.grad(square_add_prim)(2., 10.) == 4.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 2.0, UndefinedPrimal(ShapedArray(float32[])))
  call multiply_add_prim(1.0, 2.0, 0.0)
    call multiply_add_impl(1.0, 2.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 0.0)
  call multiply_add_prim(2.0, 1.0, 0.0)
    call multiply_add_impl(2.0, 1.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (None, 2.0, 1.0)

Notice the two calls to multiply_add_transpose. They correspond to the two uses of multiply_add_prim in the computation of the output_tangent in multiply_add_value_and_jvp. The first call to the transpose corresponds to the last use of multiply_add_prim: multiply_add_prim(xt, y, ...), where y is the constant 2.0.

JIT of reverse differentiation#

Notice that the abstract evaluation of multiply_add_value_and_jvp uses only abstract values, whereas in the absence of JIT you used ConcreteArray:

assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, UndefinedPrimal(ShapedArray(float32[])))
  call multiply_add_prim(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
    call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
|<- multiply_add_transpose = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>)
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
|<- multiply_add_transpose = (None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>)
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7fb2545a7440>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fb25459e520>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fb25459f640>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fb254591500>, platforms=('cpu',), backend=<jaxlib.xla_extension.Client object at 0x7fb258145080>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fb254592170>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56102edb75c0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1057/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fb2580fbaa0, file "/tmp/ipykernel_1057/1751132419.py", line 5>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0)), (<code object func_wrapper at 0x7fb2580face0, file "/tmp/ipykernel_1057/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fb2557addc0, file "/tmp/ipykernel_1057/347789876.py", line 3>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":41:0)), (<code object square_add_prim at 0x7fb2580fb730, file "/tmp/ipykernel_1057/1751132419.py", line 14>, 8): loc("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0)), (<code object <module> at 0x7fb2557acbe0, file "/tmp/ipykernel_1057/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_1057/3085343041.py":1:0)), (<code object run_code at 0x7fb29538b050, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_1057/1751132419.py': '/tmp/ipykernel_1057/1751132419.py', '/tmp/ipykernel_1057/1393342955.py': '/tmp/ipykernel_1057/1393342955.py', '/tmp/ipykernel_1057/347789876.py': '/tmp/ipykernel_1057/347789876.py', '/tmp/ipykernel_1057/3085343041.py': '/tmp/ipykernel_1057/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, 
is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1057/1751132419.py': True, '/tmp/ipykernel_1057/1393342955.py': True, '/tmp/ipykernel_1057/347789876.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1057/3085343041.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fb254590e20>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=None, xla_metadata={}), platforms=None), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fb254587d70>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7fb2545a7440>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fb25459e520>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fb25459f640>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fb254591500>, platforms=('cpu',), backend=<jaxlib.xla_extension.Client object at 0x7fb258145080>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fb254592170>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56102edb75c0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1057/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56102ee5eec0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1057/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fb2580fbaa0, file "/tmp/ipykernel_1057/1751132419.py", line 5>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0)), (<code object func_wrapper at 0x7fb2580face0, file "/tmp/ipykernel_1057/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7fb2557addc0, file "/tmp/ipykernel_1057/347789876.py", line 3>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":41:0)), (<code object square_add_prim at 0x7fb2580fb730, file "/tmp/ipykernel_1057/1751132419.py", line 14>, 8): loc("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0)), (<code object <module> at 0x7fb2557acbe0, file "/tmp/ipykernel_1057/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_1057/3085343041.py":1:0)), (<code object run_code at 0x7fb29538b050, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object multiply_add_value_and_jvp at 0x7fb2557addc0, file "/tmp/ipykernel_1057/347789876.py", line 3>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1057/347789876.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_1057/1751132419.py': '/tmp/ipykernel_1057/1751132419.py', '/tmp/ipykernel_1057/1393342955.py': '/tmp/ipykernel_1057/1393342955.py', '/tmp/ipykernel_1057/347789876.py': '/tmp/ipykernel_1057/347789876.py', '/tmp/ipykernel_1057/3085343041.py': '/tmp/ipykernel_1057/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1057/1751132419.py': True, '/tmp/ipykernel_1057/1393342955.py': True, '/tmp/ipykernel_1057/347789876.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1057/3085343041.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fb2545927d0>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=None, xla_metadata={}), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+00> : 
tensor<f32>}> : () -> tensor<f32>), Value(%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fb25874dab0>]
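
If you prefer to inspect the transposed computation without the verbose lowering output, you can print its jaxpr instead. This is again only a sketch; the exact equations printed vary across JAX versions:

from jax import make_jaxpr

# Abstractly trace grad(square_add_prim) and print the resulting jaxpr. It
# contains the multiply_add equations produced by the JVP and transpose rules.
print(make_jaxpr(api.grad(square_add_prim))(2., 10.))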

Batching#

The batching transformation takes a point-wise computation and turns it into a computation on vectors. If you try it right now, you will get a NotImplementedError:

# The arguments are two vectors instead of two scalars.
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                                   np.array([10., 20.]))
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1057/1080163607.py", line 3, in <module>
    api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1001, in vmap_f
    out_flat = batching.batch(
NotImplementedError: Batching rule for 'multiply_add' not implemented

You need to instruct JAX how to evaluate the batched version of the primitive. In this particular case, multiply_add_prim already operates point-wise for any dimension of input vectors, so the batched version can use the same multiply_add_prim implementation.

from jax.interpreters import batching

@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.
  
  This must be a JAX-traceable function.
  
  Since the `multiply_add` primitive already operates point-wise on tensors of
  arbitrary dimension, to batch it you can use the primitive itself. This works
  as long as all the inputs have the same dimensions and are batched along the
  same axes. The result is batched along the same axis as the inputs.

  Args:
    vector_arg_values: A tuple of three arguments, each being a tensor of
      matching shape.
    batch_axes: The axes that are being batched. See vmap documentation.

  Returns:
    A tuple of the result, and the result axis that was batched. 
  """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]


batching.primitive_batchers[multiply_add_p] = multiply_add_batch
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
  np.array([2., 3.]),
  np.array([10., 20.])),
  [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
    call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
        call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
        |<- multiply_add_impl = [14. 29.]
      |<- multiply_add_prim = [14. 29.]
    |<- multiply_add_batch = ([14. 29.], 0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
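
The rule above insists that all three arguments are batched along the same axis. A more permissive rule could first move mismatched batch axes to the front and replicate unmapped arguments. The following is only a sketch (multiply_add_batch_general is a hypothetical helper, not registered anywhere above), assuming unmapped arguments show up with a None batch axis:

import jax.numpy as jnp

def multiply_add_batch_general(vector_arg_values, batch_axes):
  """Sketch of a batch rule that tolerates mismatched or missing batch axes."""
  # Determine the batch size from any argument that is actually mapped.
  size = next(x.shape[ax] for x, ax in zip(vector_arg_values, batch_axes)
              if ax is not None)
  args = []
  for x, ax in zip(vector_arg_values, batch_axes):
    if ax is None:
      # Unmapped argument: replicate it along a new leading batch axis so that
      # the shape assertion in multiply_add_abstract_eval still holds.
      args.append(jnp.broadcast_to(x, (size,) + jnp.shape(x)))
    else:
      # Mapped argument: move its batch axis to the front.
      args.append(jnp.moveaxis(x, ax, 0))
  # The result is batched along the leading axis.
  return multiply_add_prim(*args), 0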

JIT of batching#

Below is an example of applying JIT to batching:

assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
                    (np.array([2., 3.]),
                     np.array([10., 20.])),
                    [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
    call multiply_add_batch((Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim(Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>)
        call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[2])
      |<- multiply_add_prim = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>
    |<- multiply_add_batch = (Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace>, 0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7fb2545a79c0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7fb2545c4040>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7fb2545c40d0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7fb254591230>, platforms=('cpu',), backend=<jaxlib.xla_extension.Client object at 0x7fb258145080>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7fb254591db0>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56102f1d6620>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_batch"("/tmp/ipykernel_1057/1827752256.py":25:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0) at callsite("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1057/1392464762.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7fb2580fbaa0, file "/tmp/ipykernel_1057/1751132419.py", line 5>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1057/1751132419.py":12:0)), (<code object func_wrapper at 0x7fb2580face0, file "/tmp/ipykernel_1057/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1057/1393342955.py":48:0)), (<code object multiply_add_batch at 0x7fb2557af260, file "/tmp/ipykernel_1057/1827752256.py", line 3>, 52): loc("multiply_add_batch"("/tmp/ipykernel_1057/1827752256.py":25:0)), (<code object square_add_prim at 0x7fb2580fb730, file "/tmp/ipykernel_1057/1751132419.py", line 14>, 8): loc("square_add_prim"("/tmp/ipykernel_1057/1751132419.py":17:0)), (<code object <module> at 0x7fb2557aee40, file "/tmp/ipykernel_1057/1392464762.py", line 1>, 48): loc("<module>"("/tmp/ipykernel_1057/1392464762.py":1:0)), (<code object run_code at 0x7fb29538b050, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_1057/1751132419.py': '/tmp/ipykernel_1057/1751132419.py', '/tmp/ipykernel_1057/1393342955.py': '/tmp/ipykernel_1057/1393342955.py', '/tmp/ipykernel_1057/1827752256.py': '/tmp/ipykernel_1057/1827752256.py', '/tmp/ipykernel_1057/1392464762.py': '/tmp/ipykernel_1057/1392464762.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, 
is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1057/1751132419.py': True, '/tmp/ipykernel_1057/1393342955.py': True, '/tmp/ipykernel_1057/1827752256.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/batching.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1057/1392464762.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='vmap'))), primitive=multiply_add, avals_in=[ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2])], avals_out=[ShapedArray(float32[2])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7fb25575acb0>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=None, xla_metadata={}), platforms=None), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fb25874e030>]