jax.extend.linear_util.cache#

jax.extend.linear_util.cache(call, *, explain=None)[source]#

用於將 WrappedFun 作為第一個參數的函數的記憶化裝飾器。

參數:
  • call (Callable) – 一個 Python 可呼叫物件,它將 WrappedFun 作為其第一個參數。WrappedFun 上的底層轉換和參數用作記憶化快取鍵的一部分。

  • explain (Callable | None | None)

返回值:

call 的記憶化版本。