jax.lax.while_loop#

jax.lax.while_loop(cond_fun, body_fun, init_val)[原始碼]#

cond_fun 為 True 時,重複呼叫 body_fun 於迴圈中。

簡而言之,類 Haskell 型別簽名

while_loop :: (a -> Bool) -> (a -> a) -> a -> a

while_loop 的語意由此 Python 實作給定

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val

與該 Python 版本不同,while_loop 是 JAX primitive,並被降低為單一 WhileOp。這使其適用於減少 jit 編譯函數的編譯時間,因為 @jit 函數中的原生 Python 迴圈結構會被展開,從而導致大型 XLA 計算。

同樣與 Python 類似物不同,迴圈傳遞值 val 在所有迭代中必須保持固定的形狀和 dtype(而不僅僅是在 NumPy 秩/形狀廣播和 dtype 提升規則方面保持一致,例如)。換句話說,上面型別簽名中的型別 a 代表具有固定形狀和 dtype 的陣列(或具有固定結構和葉節點上具有固定形狀和 dtype 的陣列的巢狀 tuple/list/dict 容器資料結構)。

與使用原生 Python 迴圈結構的另一個不同之處在於,while_loop 不是反向模式可微分的,因為 XLA 計算需要在記憶體需求上具有靜態界限。

注意

while_loop() 編譯 cond_funbody_fun,因此雖然它可以與 jit() 結合使用,但通常是不必要的。

參數:
  • cond_fun (Callable[[T], BooleanNumeric]) – 型別為 a -> Bool 的函數。

  • body_fun (Callable[[T], T]) – 型別為 a -> a 的函數。

  • init_val (T) – 型別為 a 的值,一種可以是純量、陣列或任何 pytree(巢狀 Python tuple/list/dict)的型別,代表初始迴圈攜帶值。

返回:

來自 body_fun 最後一次迭代的輸出,型別為 a

返回型別:

T