jax.custom_batching.sequential_vmap#
- jax.custom_batching.sequential_vmap(f)[原始碼]#
一種使用迴圈的
custom_vmap
特例。使用
sequential_vmap
裝飾的函數在批次處理時會在迴圈中依序呼叫。這對於本身不支援批次維度的函數很有用。例如
>>> @jax.custom_batching.sequential_vmap ... def f(x): ... jax.debug.print("{}", x) ... return x + 1 ... >>> jax.vmap(f)(jnp.arange(3)) 0 1 2 Array([1, 2, 3], dtype=int32)
其中 print 語句示範了這個
vmap()
是使用迴圈產生的。請參閱
custom_vmap
的文件以取得更多詳細資訊。