jax.extend.linear_util.WrappedFun#
- class jax.extend.linear_util.WrappedFun(f, f_transformed, transforms, stores, params, in_type, debug_info)[原始碼]#
表示要對其套用 transforms 的函數 f。
- 參數:
f (Callable) – 要轉換的函數。
transforms – 代表要套用至 f 的轉換的 (gen, gen_static_args) 元組列表。此處 gen 是產生器函數,而 gen_static_args 是產生器的靜態引數元組。如需產生器預期行為的描述,請參閱本模組開頭。
stores (tuple[Store | EqualStore | None, ...]) – transforms 輔助輸出的 out_store 列表。
params – 要作為關鍵字引數傳遞給 f 的額外參數,以及轉換後的關鍵字引數。
f_transformed (Callable)
debug_info (TracingDebugInfo | None)
- __init__(f, f_transformed, transforms, stores, params, in_type, debug_info)[原始碼]#
- 參數:
f (Callable)
f_transformed (Callable)
stores (tuple[Store | EqualStore | None, ...])
debug_info (TracingDebugInfo | None)
方法
__init__
(f, f_transformed, transforms, ...)call_wrapped
(*args, **kwargs)呼叫轉換後的函數
populate_stores
(stores)將值從 stores 複製到 self.stores。
wrap
(gen, gen_static_args, out_store)新增另一個轉換及其儲存。
屬性
f
f_transformed
transforms
stores
params
in_type
debug_info