jax.custom_batching.custom_vmap#

class jax.custom_batching.custom_vmap(fun)[原始碼]#

自訂 JAX 可轉換函數的 vmap 行為。

此裝飾器用於自訂 JAX 函數在 jax.vmap() 轉換下的行為。custom_vmap 裝飾的函數在大多數情況下(請參閱以下注意事項)將具有與底層函數相同的行為,除非在使用 jax.vmap() 進行批次處理時。當進行批次處理時,將使用使用 def_vmap() 定義的規則。

例如

>>> @jax.custom_batching.custom_vmap
... def f(x, y):
...   return x + y
...
>>> @f.def_vmap
... def f_vmap_rule(axis_size, in_batched, xs, ys):
...   assert all(in_batched)
...   assert xs.shape[0] == axis_size
...   assert ys.shape[0] == axis_size
...   out_batched = True
...   return xs * ys, out_batched
...
>>> xs = jnp.arange(3)
>>> ys = jnp.arange(1, 4)
>>> jax.vmap(f)(xs, ys)  # prints xs * ys instead of xs + ys
Array([0, 2, 6], dtype=int32)

請注意,custom_vmap 函數不支援反向模式自動微分。若要自訂 vmap 和反向模式自動微分,請將 custom_vmapjax.custom_vjp 結合使用。例如

>>> @jax.custom_vjp
... @jax.custom_batching.custom_vmap
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> @f.def_vmap
... def f_vmap_rule(axis_size, in_batched, xs, ys):
...   return jnp.cos(xs) * ys, True
...
>>> def f_fwd(x, y):
...   return f(x, y), (jnp.cos(x), jnp.sin(x), y)
...
>>> def f_bwd(res, g):
...   cos_x, sin_x, y = res
...   return (cos_x * g * y, sin_x * g)
...
>>> f.defvjp(f_fwd, f_bwd)
>>> jax.vmap(f)(jnp.zeros(3), jnp.ones(3))
Array([1., 1., 1.], dtype=float32)
>>> jax.grad(f)(jnp.zeros(()), jnp.ones(()))
Array(1., dtype=float32)

請注意,jax.custom_vjp 必須在外部,包裝 custom_vmap 裝飾的函數。

參數:

fun (Callable[..., Any])

__init__(fun)[原始碼]#
參數:

fun (Callable[..., Any])

方法

__init__(fun)

def_vmap(vmap_rule)

為此 custom_vmap 函數定義 vmap 規則。

屬性

fun

vmap_rule