jax.lax.while_loop#
- jax.lax.while_loop(cond_fun, body_fun, init_val)[原始碼]#
當
cond_fun
為 True 時,重複呼叫body_fun
於迴圈中。簡而言之,類 Haskell 型別簽名為
while_loop :: (a -> Bool) -> (a -> a) -> a -> a
while_loop
的語意由此 Python 實作給定def while_loop(cond_fun, body_fun, init_val): val = init_val while cond_fun(val): val = body_fun(val) return val
與該 Python 版本不同,
while_loop
是 JAX primitive,並被降低為單一 WhileOp。這使其適用於減少 jit 編譯函數的編譯時間,因為@jit
函數中的原生 Python 迴圈結構會被展開,從而導致大型 XLA 計算。同樣與 Python 類似物不同,迴圈傳遞值
val
在所有迭代中必須保持固定的形狀和 dtype(而不僅僅是在 NumPy 秩/形狀廣播和 dtype 提升規則方面保持一致,例如)。換句話說,上面型別簽名中的型別a
代表具有固定形狀和 dtype 的陣列(或具有固定結構和葉節點上具有固定形狀和 dtype 的陣列的巢狀 tuple/list/dict 容器資料結構)。與使用原生 Python 迴圈結構的另一個不同之處在於,
while_loop
不是反向模式可微分的,因為 XLA 計算需要在記憶體需求上具有靜態界限。注意
while_loop()
編譯cond_fun
和body_fun
,因此雖然它可以與jit()
結合使用,但通常是不必要的。- 參數:
cond_fun (Callable[[T], BooleanNumeric]) – 型別為
a -> Bool
的函數。body_fun (Callable[[T], T]) – 型別為
a -> a
的函數。init_val (T) – 型別為
a
的值,一種可以是純量、陣列或任何 pytree(巢狀 Python tuple/list/dict)的型別,代表初始迴圈攜帶值。
- 返回:
來自 body_fun 最後一次迭代的輸出,型別為
a
。- 返回型別:
T