jax.lax.fori_loop#

jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[原始碼]#

lower 迴圈到 upper,透過簡化為 jax.lax.while_loop()

Haskell 類型的簽名簡述如下

fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a

fori_loop 的語意由以下 Python 實作給出

def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val

如同 Python 版本所示,設定 upper <= lower 將不會產生任何迭代。不支援負數或自訂增量。

與 Python 版本不同,fori_loop 是根據呼叫 jax.lax.while_loop() 或呼叫 jax.lax.scan() 來實作的。如果行程計數是靜態的(表示在追蹤時已知,可能是因為 lowerupper 是 Python 整數文字),則 fori_loop 是根據 scan() 實作的,並且支援反向模式自動微分;否則,將使用 while_loop,並且不支援反向模式自動微分。有關更多資訊,請參閱這些函式的文件字串。

同樣與 Python 類似物不同,迴圈承載值 val 在所有迭代中必須保持固定的形狀和 dtype(而不僅僅是與 NumPy 秩/形狀廣播和 dtype 提升規則一致,例如)。換句話說,上面類型簽名中的類型 a 表示具有固定形狀和 dtype 的陣列(或具有固定結構和葉節點上具有固定形狀和 dtype 陣列的巢狀元組/列表/字典容器資料結構)。

注意

fori_loop() 編譯 body_fun,因此雖然它可以與 jit() 結合使用,但通常是不必要的。

參數:
  • lower – 一個整數,表示迴圈索引下限(包含)

  • upper – 一個整數,表示迴圈索引上限(排除)

  • body_fun – 類型為 (int, a) -> a 的函式。

  • init_val – 類型為 a 的初始迴圈承載值。

  • unroll (int | bool | None) – 一個可選的整數或布林值,用於決定迴圈的展開程度。如果提供整數,它將決定在迴圈的單次捲動迭代中執行多少次展開的迴圈迭代。如果提供布林值,它將決定迴圈是完全展開(即 unroll=True)還是完全保持未展開(即 unroll=False)。此參數僅在迴圈邊界為靜態已知時適用。

傳回:

來自最後一次迭代的迴圈值,類型為 a