jax.custom_batching.custom_vmap.def_vmap#

custom_vmap.def_vmap(vmap_rule)[原始碼]#

為此 custom_vmap 函數定義 vmap 規則。

參數:

vmap_rule (Callable[..., tuple[Any, Any]]) – 實作 vmap 規則的函數。此函數應接受以下引數:(1) 整數 axis_size 作為其第一個引數,(2) 與函數輸入結構相同的布林值 PyTree,指定是否批次處理每個引數,以及 (3) 批次處理的引數。它應傳回批次處理輸出的元組,以及與輸出結構相同的布林值 PyTree,指定是否批次處理每個輸出元素。請參閱 jax.custom_batching.custom_vmap() 的文件以取得一些範例。

傳回:

此方法傳遞規則,傳回未變更的 vmap_rule

傳回類型:

Callable[…, tuple[Any, Any]]