jax.lax.fori_loop#
- jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[原始碼]#
從
lower
迴圈到upper
,透過簡化為jax.lax.while_loop()
。Haskell 類型的簽名簡述如下
fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
fori_loop
的語意由以下 Python 實作給出def fori_loop(lower, upper, body_fun, init_val): val = init_val for i in range(lower, upper): val = body_fun(i, val) return val
如同 Python 版本所示,設定
upper <= lower
將不會產生任何迭代。不支援負數或自訂增量。與 Python 版本不同,
fori_loop
是根據呼叫jax.lax.while_loop()
或呼叫jax.lax.scan()
來實作的。如果行程計數是靜態的(表示在追蹤時已知,可能是因為lower
和upper
是 Python 整數文字),則fori_loop
是根據scan()
實作的,並且支援反向模式自動微分;否則,將使用while_loop
,並且不支援反向模式自動微分。有關更多資訊,請參閱這些函式的文件字串。同樣與 Python 類似物不同,迴圈承載值
val
在所有迭代中必須保持固定的形狀和 dtype(而不僅僅是與 NumPy 秩/形狀廣播和 dtype 提升規則一致,例如)。換句話說,上面類型簽名中的類型a
表示具有固定形狀和 dtype 的陣列(或具有固定結構和葉節點上具有固定形狀和 dtype 陣列的巢狀元組/列表/字典容器資料結構)。注意
fori_loop()
編譯body_fun
,因此雖然它可以與jit()
結合使用,但通常是不必要的。- 參數:
- 傳回:
來自最後一次迭代的迴圈值,類型為
a
。