jax.lax.optimization_barrier#

jax.lax.optimization_barrier(operand, /)[原始碼]#

防止編譯器跨越障礙移動操作。

最佳化障礙有多種可能的用途

  • 最佳化障礙確保所有輸入在任何依賴障礙輸出的運算子之前進行評估。這可用於強制執行特定的操作順序。

  • 最佳化障礙防止常見子表達式消除。JAX 使用此方法來實作重新具體化。

  • 最佳化障礙防止編譯器融合。也就是說,障礙之前的操作可能不會與編譯器在障礙之後的操作融合到同一個核心中。

JAX 沒有為最佳化障礙定義導數或批次處理規則。

最佳化障礙在編譯函數之外沒有效果。

參數:

operand – JAX 值的 pytree。

回傳:

JAX 值的 pytree,具有與 operand 相同的結構和內容。

範例

防止對 sin 的兩個呼叫之間進行常見子表達式消除

>>> def f(x):
...   return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x)
>>> jax.jit(f)(0.)
Array(0., dtype=float32, weak_type=True)