jax.lax.scan#

jax.lax.scan(f, init, xs=None, length=None, reverse=False, unroll=1, _split_transpose=False)[原始碼]#

掃描函數在領先陣列軸上進行,同時攜帶狀態。

簡要的 類似 Haskell 的型別簽名

scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

其中對於任何陣列型別規範 t[t] 表示具有額外領先軸的型別,如果 t 是具有陣列葉節點的 pytree(容器)型別,則 [t] 表示具有相同 pytree 結構和相應葉節點的型別,每個葉節點都有一個額外的領先軸。

xs 的型別(上面表示為 a)是陣列型別或 None,並且 ys 的型別(上面表示為 b)是陣列型別時,scan() 的語意大致由這個 Python 實作給出

def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

與該 Python 版本不同,xsys 都可能是任意 pytree 值,因此可以一次掃描多個陣列並產生多個輸出陣列。None 實際上是這種情況的一個特例,因為它代表一個空的 pytree。

同樣與該 Python 版本不同,scan() 是一個 JAX primitive,並被降低為單個 WhileOp。這使其對於減少 JIT 編譯函數的編譯時間非常有用,因為 jit() 函數中的原生 Python 迴圈結構會被展開,從而導致大型 XLA 計算。

最後,迴圈攜帶值 carry 在所有迭代中必須保持固定的形狀和 dtype(而不僅僅是在 NumPy 秩/形狀廣播和 dtype 提升規則下保持一致,例如)。換句話說,上面型別簽名中的型別 c 表示具有固定形狀和 dtype 的陣列(或具有固定結構和位於葉節點的具有固定形狀和 dtype 的陣列的巢狀 tuple/list/dict 容器資料結構)。

注意

scan() 編譯 f,因此雖然它可以與 jit() 結合使用,但通常是不必要的。

參數:
  • f (Callable[[Carry, X], tuple[Carry, Y]]) – 要掃描的 Python 函數,型別為 c -> a -> (c, b),表示 f 接受兩個引數,其中第一個是迴圈攜帶值,第二個是 xs 沿其領先軸的切片,並且 f 回傳一個 pair,其中第一個元素表示迴圈攜帶值的新值,第二個元素表示輸出的切片。

  • init (Carry) – 初始迴圈攜帶值,型別為 c,可以是純量、陣列或其任何 pytree(巢狀 Python tuple/list/dict),表示初始迴圈攜帶值。此值必須與 f 回傳的 pair 的第一個元素具有相同的結構。

  • xs (X | None) – 要沿著領先軸掃描的值,型別為 [a],其中 [a] 可以是陣列或任何 pytree(巢狀 Python tuple/list/dict),並具有一致的領先軸大小。

  • length (int | None) – 可選整數,指定迴圈迭代次數,必須與 xs 中陣列的領先軸大小一致(但可用於執行不需要輸入 xs 的掃描)。

  • reverse (bool) – 可選布林值,指定是向前(預設)還是向後執行掃描迭代,相當於反轉 xsys 中陣列的領先軸。

  • unroll (int | bool) – 可選正整數或布林值,用於指定 scan primitive 的底層操作中,在單個迴圈迭代中展開多少次掃描迭代。如果提供整數,它將決定在迴圈的單個滾動迭代中執行多少次展開的迴圈迭代。如果提供布林值,它將決定迴圈是完全展開(即 unroll=True)還是完全保持未展開(即 unroll=False)。

  • _split_transpose (bool) – 實驗性的可選布林值,用於指定是否將轉置進一步拆分為掃描(計算激活梯度)和映射(計算對應於陣列引數的梯度)。啟用此功能可能會增加記憶體需求,因此這是一個實驗性功能,可能會發展甚至被撤回。

回傳值:

型別為 (c, [b]) 的 pair,其中第一個元素表示最終迴圈攜帶值,第二個元素表示掃描輸入的領先軸時 f 的第二個輸出的堆疊輸出。

回傳型別:

tuple[Carry, Y]