jax.linear_transpose#

jax.linear_transpose(fun, *primals, reduce_axes=())[原始碼]#

轉置一個保證為線性的函數。

對於線性函數,此轉換等效於 vjp(),但避免了計算前向傳遞的開銷。

轉置函數的輸出將始終具有與 primals 完全相同的資料類型,即使某些值被截斷(例如,從複數到浮點數,或從 float64 到 float32)。為了避免截斷,請在 primals 中使用與轉置函數期望輸出的完整範圍相符的資料類型。不支援整數資料類型。

參數:
  • fun (Callable) – 要轉置的線性函數。

  • *primals – 用於評估 fun(*primals) 的形狀/資料類型的位置參數元組,可以是陣列、純量或這些類型的(巢狀)標準 Python 容器(元組、列表、字典、namedtuples,即 pytrees)。這些參數可以是實數純量/ndarray,但非必要:僅存取 shapedtype 屬性。請參閱下面的範例。(請注意,duck-typed 物件不能是 namedtuples,因為它們被視為標準 Python 容器。)

回傳:

一個可呼叫物件,用於計算 fun 的轉置。此函數的有效輸入必須具有與 fun(*primals) 結果相同的形狀/資料類型/結構。輸出將是一個元組,具有與 primals 相同的形狀/資料類型/結構。

回傳類型:

Callable

>>> import jax
>>> import types
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(Array(0.5, dtype=float32), Array(-0.5, dtype=float32))