jax.custom_batching.custom_vmap#
- class jax.custom_batching.custom_vmap(fun)[原始碼]#
自訂 JAX 可轉換函數的 vmap 行為。
此裝飾器用於自訂 JAX 函數在
jax.vmap()
轉換下的行為。custom_vmap
裝飾的函數在大多數情況下(請參閱以下注意事項)將具有與底層函數相同的行為,除非在使用jax.vmap()
進行批次處理時。當進行批次處理時,將使用使用def_vmap()
定義的規則。例如
>>> @jax.custom_batching.custom_vmap ... def f(x, y): ... return x + y ... >>> @f.def_vmap ... def f_vmap_rule(axis_size, in_batched, xs, ys): ... assert all(in_batched) ... assert xs.shape[0] == axis_size ... assert ys.shape[0] == axis_size ... out_batched = True ... return xs * ys, out_batched ... >>> xs = jnp.arange(3) >>> ys = jnp.arange(1, 4) >>> jax.vmap(f)(xs, ys) # prints xs * ys instead of xs + ys Array([0, 2, 6], dtype=int32)
請注意,
custom_vmap
函數不支援反向模式自動微分。若要自訂 vmap 和反向模式自動微分,請將custom_vmap
與jax.custom_vjp
結合使用。例如>>> @jax.custom_vjp ... @jax.custom_batching.custom_vmap ... def f(x, y): ... return jnp.sin(x) * y ... >>> @f.def_vmap ... def f_vmap_rule(axis_size, in_batched, xs, ys): ... return jnp.cos(xs) * ys, True ... >>> def f_fwd(x, y): ... return f(x, y), (jnp.cos(x), jnp.sin(x), y) ... >>> def f_bwd(res, g): ... cos_x, sin_x, y = res ... return (cos_x * g * y, sin_x * g) ... >>> f.defvjp(f_fwd, f_bwd) >>> jax.vmap(f)(jnp.zeros(3), jnp.ones(3)) Array([1., 1., 1.], dtype=float32) >>> jax.grad(f)(jnp.zeros(()), jnp.ones(())) Array(1., dtype=float32)
請注意,
jax.custom_vjp
必須在外部,包裝custom_vmap
裝飾的函數。- 參數:
fun (Callable[..., Any])
方法
__init__
(fun)def_vmap
(vmap_rule)為此 custom_vmap 函數定義 vmap 規則。
屬性
fun
vmap_rule