jax.custom_batching.sequential_vmap#

jax.custom_batching.sequential_vmap(f)[原始碼]#

一種使用迴圈的 custom_vmap 特例。

使用 sequential_vmap 裝飾的函數在批次處理時會在迴圈中依序呼叫。這對於本身不支援批次維度的函數很有用。

例如

>>> @jax.custom_batching.sequential_vmap
... def f(x):
...   jax.debug.print("{}", x)
...   return x + 1
...
>>> jax.vmap(f)(jnp.arange(3))
0
1
2
Array([1, 2, 3], dtype=int32)

其中 print 語句示範了這個 vmap() 是使用迴圈產生的。

請參閱 custom_vmap 的文件以取得更多詳細資訊。