jax.custom_batching.custom_vmap.def_vmap#
- custom_vmap.def_vmap(vmap_rule)[原始碼]#
為此 custom_vmap 函數定義 vmap 規則。
- 參數:
vmap_rule (Callable[..., tuple[Any, Any]]) – 實作 vmap 規則的函數。此函數應接受以下引數:(1) 整數
axis_size
作為其第一個引數,(2) 與函數輸入結構相同的布林值 PyTree,指定是否批次處理每個引數,以及 (3) 批次處理的引數。它應傳回批次處理輸出的元組,以及與輸出結構相同的布林值 PyTree,指定是否批次處理每個輸出元素。請參閱jax.custom_batching.custom_vmap()
的文件以取得一些範例。- 傳回:
此方法傳遞規則,傳回未變更的
vmap_rule
。- 傳回類型:
Callable[…, tuple[Any, Any]]