Omnistaging#
mattjj@ 2020 年 9 月 25 日
這更像是一份升級指南,而不是設計文件。
目錄#
太長不看 (tl;dr)#
發生什麼事?#
JAX 的追蹤基礎架構的一項變更,稱為「omnistaging」(jax-ml/jax#3370)已在 jax==0.2.0 中開啟。此變更改善了記憶體效能、追蹤執行時間,並簡化了 jax 內部機制,但可能會導致一些現有程式碼中斷。中斷通常是錯誤程式碼的結果,因此從長遠來看,最好修正錯誤,但 omnistaging 也可以作為臨時的權宜之計停用。我們很樂意協助您進行修正!
我如何知道 omnistaging 破壞了我的程式碼?#
判斷是否是 omnistaging 造成問題的最簡單方法是停用 omnistaging,看看問題是否消失。請參閱下方的當 omnistaging 開啟時,可能會出現哪些問題?章節。
我現在可以停用 omnistaging 嗎?#
注意:這適用於 JAX 版本 0.2.0 到 0.2.11;在 JAX 版本 0.2.12 及更高版本中,無法停用 omnistaging
暫時可以透過以下方式停用 omnistaging:
將 shell 環境變數
JAX_OMNISTAGING
設定為 falsy 值;如果您的程式碼使用 absl 解析旗標,則將布林旗標
jax_omnistaging
設定為 falsy 值;在您的主要檔案頂端附近使用此陳述式
jax.config.disable_omnistaging()
我該如何修正 omnistaging 暴露的錯誤?#
到目前為止,omnistaging 最常見的問題是使用 jax.numpy
計算形狀值或其他追蹤時間常數。請參閱下方的程式碼區塊以取得快速範例,如需完整詳細資訊以及其他問題,請參閱當 omnistaging 開啟時,可能會出現哪些問題?章節。
不要這樣做
@jit
def f(x):
input_size = jnp.prod(x.shape)
if input_size > 100:
...
改為這樣做
import numpy as np
@jit
def f(x):
input_size = np.prod(x.shape)
if input_size > 100:
...
不要將 jax.numpy
視為 numpy
的直接替代品,現在最好將 jax.numpy
操作僅用於想要在加速器(例如您的 GPU)上執行計算時。
什麼是「omnistaging」,它為什麼有用?#
Omnistaging 是 JAX 核心升級的名稱,旨在將更多計算從逐操作的 Python 暫存到 XLA,並避免 jit
、pmap
和控制流程 primitives 中的任何「追蹤時間常數折疊」。因此,omnistaging 透過減少追蹤期間的碎片化,以及為 XLA 產生更少的編譯時大型常數,來改善 JAX 的記憶體效能(有時顯著地)。它還可以透過消除追蹤時的逐操作執行來改善追蹤效能。此外,omnistaging 簡化了 JAX 核心內部機制,修正了許多未解決的錯誤,並為即將推出的重要功能奠定了基礎。
名稱「omnistaging」表示暫存所有可能的東西。
玩具範例#
JAX 轉換(例如 jit
和 pmap
)將計算暫存到 XLA。也就是說,我們將它們應用於包含多個 primitive 操作的函數,以便操作不是從 Python 一次執行一個,而是都成為一個端對端優化的 XLA 計算的一部分。
但是究竟哪些操作被暫存?在 omnistaging 之前,JAX 僅根據資料依賴性來暫存計算。以下是一個範例函數,以及在 omnistaging 變更之前它暫存的 XLA HLO 程式
from jax import jit
import jax.numpy as jnp
@jit
def f(x):
y = jnp.add(1, 1)
return x * y
f(3)
ENTRY jit_f.6 {
constant.2 = pred[] constant(false)
parameter.1 = s32[] parameter(0)
constant.3 = s32[] constant(2)
multiply.4 = s32[] multiply(parameter.1, constant.3)
ROOT tuple.5 = (s32[]) tuple(multiply.4)
}
請注意,add
操作未被暫存。相反地,我們只看到一個乘法。
以下是從此函數在 omnistaging 變更之後產生的 HLO
ENTRY jit_f.8 {
constant.2 = pred[] constant(false)
parameter.1 = s32[] parameter(0)
constant.3 = s32[] constant(1)
constant.4 = s32[] constant(1)
add.5 = s32[] add(constant.3, constant.4)
multiply.6 = s32[] multiply(parameter.1, add.5)
ROOT tuple.7 = (s32[]) tuple(multiply.6)
}
稍微不那麼玩具的範例#
以下是一個不那麼玩具的範例,當我們想要建立布林遮罩時,在實務中可能會出現
import jax.numpy as jnp
from jax import lax
@jit
def select_tril(x):
mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
return lax.select(mask, x, jnp.zeros_like(x)) # lax.select is like jnp.where
x = np.arange(12).reshape((3, 4))
select_tril(x)
在 omnistaging 之前
ENTRY jit_select_tril.8 {
constant.3 = pred[] constant(false)
constant.1 = pred[3,4]{1,0} constant({...})
parameter.2 = s32[3,4]{1,0} parameter(0)
constant.4 = s32[] constant(0)
broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}
select
操作被暫存,但用於建構常數 mask
的操作則沒有。用於建構 mask
的操作不是被暫存,而是在 Python 追蹤時間逐操作地執行,XLA 只看到一個編譯時常數 constant.1
,代表 mask
的值。這很遺憾,因為如果我們暫存了用於建構 mask
的操作,XLA 可以將它們融合到 select
中,並完全避免實現結果。因此,我們最終浪費了記憶體在可能很大的常數上,浪費了時間分派多個未融合的逐操作 XLA 計算,甚至可能造成記憶體碎片化。
(對應於 jnp.zeros_like(x)
的零陣列建構的 broadcast
被暫存,因為 JAX 對於來自 jax-ml/jax#1668 的非常簡單的表達式很寬鬆。在 omnistaging 之後,我們可以移除該寬鬆的子語言並簡化 JAX 內部機制。)
mask
的建立未被暫存的原因是,在 omnistaging 之前,jit
基於資料依賴性運作。也就是說,jit
僅暫存函數中那些對引數具有資料依賴性的操作。控制流程 primitives 和 pmap
的行為類似。在 select_tril
的情況下,建構常數 mask
的操作對引數 x 沒有資料依賴性,因此它們不會被暫存;只有 lax.select
呼叫具有資料依賴性。
透過 omnistaging,jit
轉換函數的動態上下文中的所有 jax.numpy
呼叫都會被暫存到 XLA。也就是說,在 omnistaging 之後,XLA 看到的 select_tril
計算是
ENTRY jit_select_tril.16 {
constant.4 = pred[] constant(false)
iota.1 = s32[3]{0} iota(), iota_dimension=0
broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
reshape.7 = s32[3]{0} reshape(broadcast.5)
broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
iota.2 = s32[4]{0} iota(), iota_dimension=0
broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
reshape.9 = s32[4]{0} reshape(broadcast.6)
broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
parameter.3 = s32[3,4]{1,0} parameter(0)
constant.12 = s32[] constant(0)
broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}
當 omnistaging 開啟時,可能會出現哪些問題?#
由於在 jit
或 pmap
的動態上下文中,將所有 jax.numpy
操作從 Python 暫存到 XLA 的結果,一些先前可以運作的程式碼可能會開始引發明顯的錯誤。如下所述,這些行為在 omnistaging 之前已經有錯誤,但 omnistaging 使它們成為硬性錯誤。
使用 jax.numpy
進行形狀計算#
範例#
from jax import jit
import jax.numpy as jnp
@jit
def ex1(x):
size = jnp.prod(jnp.array(x.shape))
return x.reshape((size,))
ex1(jnp.ones((3, 4)))
錯誤訊息#
[... full traceback ...]
File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
The error arose in jax.numpy.reshape.
While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:
operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
from line ex1.py:6 (ex1)
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.
See https://jax.dev.org.tw/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
說明#
透過 omnistaging,我們無法像上方使用 jnp.prod
那樣使用 jax.numpy
進行形狀計算,因為在 jit 函數的動態上下文中,這些操作將從 Python 中暫存為在執行階段計算的值,但我們需要它們成為編譯時(以及因此追蹤時間)常數。
在 omnistaging 之前,此程式碼不會引發錯誤,但這是一個常見的效能錯誤:jnp.prod
計算會在追蹤時間在裝置上執行,這意味著額外的編譯、傳輸、同步、分配,以及可能的記憶體碎片化。
解決方案#
解決方案很簡單,就是對於這些形狀計算使用原始的 numpy
。我們不僅避免了錯誤,而且還將計算保留在主機上(並降低了額外負擔)。
這個問題在程式碼中非常常見,以至於我們試圖使錯誤訊息特別好。除了顯示抽象追蹤器值導致問題的堆疊追蹤(完整堆疊追蹤中的 jnp.reshape
行,位於 omni.py:10),我們還解釋了為什麼這個值首先變成追蹤器,方法是指向上游的 primitive 操作,該操作導致它變成抽象追蹤器(來自 omni.py:9 上 jnp.prod
的 reduce_prod
),以及追蹤器所屬的 jit
裝飾函數(omni.py:6 上的 ex1
)。
副作用#
範例#
from jax import jit
from jax import random
key = random.PRNGKey(0)
def init():
global key
key, subkey = random.split(key)
return random.normal(subkey, ())
print(init()) # -1.2515389
print(init()) # -0.58665067
init = jit(init)
print(init()) # 0.48648298
print(init()) # 0.48648298 !!
最後一個呼叫具有重複的隨機性,但沒有硬性錯誤,因為我們沒有重新執行 Python。但是,如果我們查看 key
,我們會看到一個逸出的追蹤器當 omnistaging 開啟時
print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
在 omnistaging 之前,random.split
呼叫不會被暫存,因此我們不會得到逸出的追蹤器。程式碼仍然會有錯誤,因為 jitted 函數不會重現原始函數的語意(因為重複使用相同的 PRNG 金鑰),最終是由於副作用。
當 omnistaging 開啟時,如果我們再次接觸 key
,我們將收到逸出的追蹤器錯誤
random.normal(key, ())
錯誤訊息#
[... full stack trace …]
File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).
說明#
我們發現的第二大類 omnistaging 問題與副作用程式碼有關。此程式碼已經透過轉換 effectful 函數而使 JAX 保固失效,但由於 omnistaging 之前的「追蹤時間常數折疊」行為,某些副作用函數仍然可以正確運作。Omnistaging 捕捉到更多此類錯誤。
解決方案#
解決方案是識別依賴副作用的 JAX 轉換函數,並重新編寫它們以使其沒有副作用。
基於 XLA 優化的微小數值差異#
因為透過 omnistaging,更多的計算被暫存到 XLA,而不是一些在追蹤時間執行的計算,這可能會重新排序浮點運算。因此,我們已經看到數值行為發生變化,導致公差過於嚴格的測試在 omnistaging 開啟時失敗。
依賴已變更的 JAX 內部 API#
Omnistaging 涉及對 JAX 核心程式碼進行一些重大修訂,包括移除或變更內部函數。任何依賴此類內部 JAX API 的程式碼都可能在 omnistaging 開啟時中斷,無論是建置錯誤 (來自 pytype) 還是執行階段錯誤。
觸發 XLA 編譯時錯誤#
由於 omnistaging 涉及將更多程式碼暫存到 XLA,因此我們已經看到它在某些後端觸發了先前存在的 XLA 編譯時錯誤。對於這些錯誤,最好的做法是回報它們,以便我們可以與 XLA 團隊合作進行修正。