jax.extend.linear_util.WrappedFun#

class jax.extend.linear_util.WrappedFun(f, f_transformed, transforms, stores, params, in_type, debug_info)[原始碼]#

表示要對其套用 transforms 的函數 f

參數:
  • f (Callable) – 要轉換的函數。

  • transforms – 代表要套用至 f 的轉換的 (gen, gen_static_args) 元組列表。此處 gen 是產生器函數,而 gen_static_args 是產生器的靜態引數元組。如需產生器預期行為的描述,請參閱本模組開頭。

  • stores (tuple[Store | EqualStore | None, ...]) – transforms 輔助輸出的 out_store 列表。

  • params – 要作為關鍵字引數傳遞給 f 的額外參數,以及轉換後的關鍵字引數。

  • f_transformed (Callable)

  • debug_info (TracingDebugInfo | None)

__init__(f, f_transformed, transforms, stores, params, in_type, debug_info)[原始碼]#
參數:
  • f (Callable)

  • f_transformed (Callable)

  • stores (tuple[Store | EqualStore | None, ...])

  • debug_info (TracingDebugInfo | None)

方法

__init__(f, f_transformed, transforms, ...)

call_wrapped(*args, **kwargs)

呼叫轉換後的函數

populate_stores(stores)

將值從 stores 複製到 self.stores

wrap(gen, gen_static_args, out_store)

新增另一個轉換及其儲存。

屬性

f

f_transformed

transforms

stores

params

in_type

debug_info