常見問題 (FAQ)#
我們在此收集常見問題的解答。歡迎貢獻!
jit
變更了我的函式行為#
如果您在使用 jax.jit()
後,Python 函式行為發生變更,則可能是您的函式使用了全域狀態,或具有副作用。在以下程式碼中,impure_func
使用了全域 y
,並且由於 print
而具有副作用
y = 0
# @jit # Different behavior with jit
def impure_func(x):
print("Inside:", y)
return x + y
for y in range(3):
print("Result:", impure_func(y))
不使用 jit
時,輸出為
Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4
使用 jit
時,輸出為
Inside: 0
Result: 0
Result: 1
Result: 2
對於 jax.jit()
,函式會使用 Python 解譯器執行一次,此時會發生 Inside
列印,並且觀察到 y
的第一個值。然後,函式會被編譯和快取,並使用不同的 x
值執行多次,但使用相同的 y
第一個值。
延伸閱讀
jit
變更了輸出的精確數值#
有時使用者會驚訝於使用 jit()
包裝函式會變更函式的輸出。例如
>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
... return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649
輸出中的這種細微差異來自 XLA 編譯器內的最佳化:在編譯期間,XLA 有時會重新排列或省略某些操作,以使整體計算更有效率。
在本例中,XLA 利用對數的屬性將 log(sqrt(x))
替換為 0.5 * log(x)
,這是一個數學上相同的表達式,可以比原始表達式更有效率地計算。輸出差異來自於浮點運算僅是實際數學的近似值,因此計算相同表達式的不同方式可能會產生細微不同的結果。
其他時候,XLA 的最佳化可能會導致更顯著的差異。考慮以下範例
>>> def f(x):
... return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0
在非 JIT 編譯的逐操作模式下,結果為 inf
,因為 jnp.exp(x)
溢位並傳回 inf
。但是,在 JIT 下,XLA 識別到 log
是 exp
的反函數,並從編譯後的函式中移除這些操作,僅傳回輸入。在本例中,JIT 編譯產生了更準確的實際結果浮點近似值。
不幸的是,XLA 的代數簡化完整列表沒有充分的文件記錄,但如果您熟悉 C++ 並且好奇 XLA 編譯器進行哪些類型的最佳化,您可以在原始碼中查看它們:algebraic_simplifier.cc。
jit
裝飾的函式編譯速度非常慢#
如果您的 jit
裝飾函式在您第一次呼叫時需要數十秒(或更長時間)才能執行,但在再次呼叫時執行速度很快,則表示 JAX 花費了很長時間來追蹤或編譯您的程式碼。
這通常表示呼叫您的函式會在 JAX 的內部表示中產生大量程式碼,通常是因為它大量使用了 Python 控制流程,例如 for
迴圈。對於少數迴圈迭代,Python 還可以,但如果您需要許多迴圈迭代,您應該重寫程式碼以使用 JAX 的結構化控制流程 primitives(例如 lax.scan()
),或避免使用 jit
包裝迴圈(您仍然可以在迴圈內部使用 jit
裝飾的函式)。
如果您不確定這是否是問題所在,您可以嘗試在您的函式上執行 jax.make_jaxpr()
。如果輸出長達數百或數千行,您可以預期編譯速度會很慢。
有時,由於您的程式碼使用了許多形狀不同的陣列,因此不清楚如何重寫程式碼以避免 Python 迴圈。在這種情況下,建議的解決方案是使用 jax.numpy.where()
等函式,在具有固定形狀的填充陣列上執行計算。
如果您的函式由於其他原因而編譯緩慢,請在 GitHub 上開啟 issue。
如何搭配方法使用 jit
?#
jax.jit()
的大多數範例都與裝飾獨立 Python 函式有關,但在類別中裝飾方法會引入一些複雜性。例如,考慮以下簡單類別,我們在方法上使用了標準 jit()
註解
>>> import jax.numpy as jnp
>>> from jax import jit
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @jit # <---- How to do this correctly?
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
但是,當您嘗試呼叫此方法時,此方法會導致錯誤
>>> c = CustomClass(2, True)
>>> c.calc(3)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.
問題在於函式的第一個引數是 self
,其型別為 CustomClass
,而 JAX 不知道如何處理此型別。在這種情況下,我們可以使用三種基本策略,我們將在下面討論它們。
策略 1:JIT 編譯的輔助函式#
最直接的方法是建立一個類別外部的輔助函式,該函式可以使用正常方式進行 JIT 裝飾。例如
>>> from functools import partial
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... def calc(self, y):
... return _calc(self.mul, self.x, y)
>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
... if mul:
... return x * y
... return y
結果將如預期般運作
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
這種方法的好處是它簡單、明確,並且避免了教導 JAX 如何處理 CustomClass
型別物件的需求。但是,您可能希望將所有方法邏輯保留在同一個位置。
策略 2:將 self
標記為靜態#
另一個常見模式是使用 static_argnums
將 self
引數標記為靜態。但必須謹慎執行此操作,以避免產生非預期的結果。您可能會想簡單地執行此操作
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... # WARNING: this example is broken, as we'll see below. Don't copy & paste!
... @partial(jit, static_argnums=0)
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
如果您呼叫該方法,它將不再引發錯誤
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
但是,有一個陷阱:如果您在第一次方法呼叫後變更物件,則後續方法呼叫可能會傳回不正確的結果
>>> c.mul = False
>>> print(c.calc(3)) # Should print 3
6
這是為什麼?當您將物件標記為靜態時,它實際上將用作 JIT 內部編譯快取中的字典金鑰,這表示其雜湊(即 hash(obj)
)相等性(即 obj1 == obj2
)和物件識別(即 obj1 is obj2
)將被假定為具有一致的行為。自訂物件的預設 __hash__
是其物件 ID,因此 JAX 無法知道變更的物件是否應觸發重新編譯。
您可以透過為物件定義適當的 __hash__
和 __eq__
方法來部分解決此問題;例如
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @partial(jit, static_argnums=0)
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
...
... def __hash__(self):
... return hash((self.x, self.mul))
...
... def __eq__(self, other):
... return (isinstance(other, CustomClass) and
... (self.x, self.mul) == (other.x, other.mul))
(有關覆寫 __hash__
時需求的更多討論,請參閱 object.__hash__()
文件)。
只要您永遠不要變更您的物件,這應該可以與 JIT 和其他轉換正確協同運作。用作雜湊金鑰的物件變更會導致幾個微妙的問題,這就是為什麼例如可變 Python 容器(例如 dict
、list
)未定義 __hash__
,而它們的不可變對應項(例如 tuple
)則定義了。
如果您的類別依賴就地變更(例如在其方法中設定 self.attr = ...
),則您的物件並非真正「靜態」,將其標記為靜態可能會導致問題。幸運的是,這種情況還有另一種選擇。
策略 3:將 CustomClass
設為 PyTree#
正確 JIT 編譯類別方法的最彈性方法是將型別註冊為自訂 PyTree 物件;請參閱擴充 pytrees。這可讓您明確指定類別的哪些元件應視為靜態,哪些應視為動態。以下是其外觀範例
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @jit
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
...
... def _tree_flatten(self):
... children = (self.x,) # arrays / dynamic values
... aux_data = {'mul': self.mul} # static values
... return (children, aux_data)
...
... @classmethod
... def _tree_unflatten(cls, aux_data, children):
... return cls(*children, **aux_data)
>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
... CustomClass._tree_flatten,
... CustomClass._tree_unflatten)
這當然更複雜,但它解決了與上述更簡單方法相關的所有問題
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
>>> c.mul = False # mutation is detected
>>> print(c.calc(3))
3
>>> c = CustomClass(jnp.array(2), True) # non-hashable x is supported
>>> print(c.calc(3))
6
只要您的 tree_flatten
和 tree_unflatten
函式正確處理類別中的所有相關屬性,您應該能夠直接將此型別的物件用作 JIT 編譯函式的引數,而無需任何特殊註解。
控制裝置上的資料和計算放置#
讓我們先看看 JAX 中資料和計算放置的原則。
在 JAX 中,計算遵循資料放置。JAX 陣列具有兩個放置屬性:1) 資料所在的裝置;以及 2) 是否已提交到裝置(資料有時稱為黏著到裝置)。
預設情況下,JAX 陣列會以未提交狀態放置在預設裝置 (jax.devices()[0]
) 上,預設情況下,預設裝置是第一個 GPU 或 TPU。如果沒有 GPU 或 TPU,則 jax.devices()[0]
是 CPU。可以使用 jax.default_device()
內容管理器暫時覆寫預設裝置,或透過將環境變數 JAX_PLATFORMS
或 absl 標誌 --jax_platforms
設定為 “cpu”、“gpu” 或 “tpu” 來為整個程序設定預設裝置 (JAX_PLATFORMS
也可以是平台列表,這決定了哪些平台按優先順序可用)。
>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())
{CudaDevice(id=0)}
涉及未提交資料的計算會在預設裝置上執行,結果會以未提交狀態放置在預設裝置上。
也可以使用 jax.device_put()
和 device
參數將資料明確放置在裝置上,在這種情況下,資料會提交到裝置
>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])
>>> print(arr.devices())
{CudaDevice(id=2)}
涉及某些已提交輸入的計算將在已提交的裝置上發生,結果將以已提交狀態放置在同一裝置上。在提交到多個裝置的引數上調用操作會引發錯誤。
您也可以使用不帶 device
參數的 jax.device_put()
。如果資料已在裝置上(已提交或未提交),則會保持不變。如果資料不在任何裝置上,也就是說,它是常規 Python 或 NumPy 值,則會以未提交狀態放置在預設裝置上。
Jitted 函式的行為與任何其他基本操作類似,它們將遵循資料,如果在提交到多個裝置的資料上調用,則會顯示錯誤。
(在 2021 年 3 月的 PR #6002 之前,陣列常數的建立存在一些延遲,因此 jax.device_put(jnp.zeros(...), jax.devices()[1])
或類似程式碼實際上會在 jax.devices()[1]
上建立零陣列,而不是在預設裝置上建立陣列然後移動它。但是,此最佳化已移除,以簡化實作。)
(截至 2020 年 4 月,jax.jit()
具有影響裝置放置的 device 參數。該參數是實驗性的,可能會被移除或變更,不建議使用。)
對於已完成的範例,我們建議閱讀 test_computation_follows_data
in multi_device_test.py。
JAX 程式碼基準測試#
您剛剛將一個棘手的函式從 NumPy/SciPy 移植到 JAX。這實際上加快了速度嗎?
在測量使用 JAX 的程式碼速度時,請記住與 NumPy 的這些重要差異
JAX 程式碼是即時 (JIT) 編譯的。 大多數以 JAX 撰寫的程式碼都可以以支援 JIT 編譯的方式撰寫,這可以使其執行速度快得多(請參閱 要 JIT 還是不要 JIT)。為了從 JAX 獲得最大效能,您應該在最外層的函式呼叫上套用
jax.jit()
。請記住,第一次執行 JAX 程式碼時,速度會較慢,因為它正在編譯中。即使您不在自己的程式碼中使用
jit
,也是如此,因為 JAX 的內建函式也是 JIT 編譯的。JAX 具有非同步分派。 這表示您需要呼叫
.block_until_ready()
以確保計算實際發生(請參閱 非同步分派)。JAX 預設僅使用 32 位元 dtype。 您可能想要在 NumPy 中明確使用 32 位元 dtype,或在 JAX 中啟用 64 位元 dtype(請參閱 雙精度 (64 位元))以進行公平比較。
在 CPU 和加速器之間傳輸資料需要時間。 如果您只想測量評估函式所需的時間,您可能需要先將資料傳輸到您想要在其上執行的裝置(請參閱 控制裝置上的資料和計算放置)。
以下是如何將所有這些技巧組合到一個微基準測試中,以比較 JAX 與 NumPy,並使用 IPython 方便的 %time 和 %timeit magics 的範例
import numpy as np
import jax.numpy as jnp
import jax
def f(x): # function we're benchmarking (works in both NumPy & JAX)
return x.T @ (x - x.mean(axis=0))
x_np = np.ones((1000, 1000), dtype=np.float32) # same as JAX default dtype
%timeit f(x_np) # measure NumPy runtime
%time x_jax = jax.device_put(x_np) # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready() # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready() # measure JAX runtime
在 Colab 中使用 GPU 執行時,我們看到
NumPy 在 CPU 上每次評估需要 16.2 毫秒
JAX 需要 1.26 毫秒才能將 NumPy 陣列複製到 GPU
JAX 需要 193 毫秒才能編譯函式
JAX 在 GPU 上每次評估需要 485 微秒
在本例中,我們看到一旦資料傳輸完成且函式編譯完成,GPU 上的 JAX 對於重複評估而言速度快約 30 倍。
這是一個公平的比較嗎?也許是。最終重要的效能是用於執行完整的應用程式,這不可避免地包括一定數量的資料傳輸和編譯。此外,我們謹慎地選擇了足夠大的陣列 (1000x1000) 和足夠密集的計算(@
運算子正在執行矩陣-矩陣乘法),以攤銷 JAX/加速器與 NumPy/CPU 相比增加的額外負擔。例如,如果我們將此範例切換為使用 10x10 輸入,則 JAX/GPU 的執行速度比 NumPy/CPU 慢 10 倍(100 微秒 vs 10 微秒)。
JAX 比 NumPy 快嗎?#
使用者經常嘗試使用此類基準測試回答的一個問題是 JAX 是否比 NumPy 快;由於這兩個套件的差異,沒有簡單的答案。
廣義來說
NumPy 操作是急切、同步地執行,並且僅在 CPU 上執行。
JAX 操作可能會在編譯後急切執行或執行(如果在
jit()
內部);它們是非同步分派的(請參閱 非同步分派);並且它們可以在 CPU、GPU 或 TPU 上執行,每個都有截然不同且不斷發展的效能特性。
這些架構差異使得 NumPy 和 JAX 之間有意義的直接基準比較變得困難。
此外,這些差異導致了套件之間不同的工程重點:例如,NumPy 已投入大量精力來減少個別陣列操作的每次呼叫分派額外負擔,因為在 NumPy 的計算模型中,這種額外負擔是無法避免的。另一方面,JAX 有幾種方法可以避免分派額外負擔(例如 JIT 編譯、非同步分派、批次轉換等),因此減少每次呼叫的額外負擔已不再是優先事項。
記住所有這些,總之:如果您在 CPU 上執行個別陣列操作的微基準測試,由於 NumPy 的每次操作分派額外負擔較低,您通常可以預期 NumPy 的效能會優於 JAX。如果您在 GPU 或 TPU 上執行程式碼,或在 CPU 上基準測試更複雜的 JIT 編譯操作序列,您通常可以預期 JAX 的效能會優於 NumPy。
不同種類的 JAX 值#
在轉換函式的過程中,JAX 會將一些函式引數替換為特殊的追蹤器值。
如果您使用 print
陳述式,您可能會看到這一點
def func(x):
print(x)
return jnp.cos(x)
res = jax.jit(func)(0.)
上述程式碼確實會回傳正確的值 1.
,但同時也會印出 Traced<ShapedArray(float32[])>
作為 x
的值。一般來說,JAX 會以透明的方式在內部處理這些追蹤器 (tracer) 值,例如,在用於實作 jax.numpy
函式的數值 JAX 原語中。這就是為什麼 jnp.cos
在上述範例中可以運作的原因。
更精確地說,追蹤器 (tracer) 值會被引入作為 JAX 轉換函式的引數,但由特殊參數 (例如 jax.jit()
的 static_argnums
或 jax.pmap()
的 static_broadcasted_argnums
) 識別的引數除外。通常,涉及至少一個追蹤器值的計算會產生一個追蹤器值。除了追蹤器值之外,還有常規 (regular) Python 值:在 JAX 轉換之外計算的值,或從上述某些 JAX 轉換的靜態引數中產生,或僅從其他常規 Python 值計算而來的值。這些值是在沒有 JAX 轉換的情況下到處使用的值。
追蹤器值帶有一個抽象 (abstract) 值,例如,具有陣列形狀和 dtype 資訊的 ShapedArray
。我們在這裡將此類追蹤器稱為抽象追蹤器 (abstract tracers)。有些追蹤器,例如,為自動微分轉換的引數引入的追蹤器,帶有 ConcreteArray
抽象值,實際上包含了常規陣列資料,並用於例如解析條件式。我們在這裡將此類追蹤器稱為具體追蹤器 (concrete tracers)。從這些具體追蹤器計算出的追蹤器值,或許與常規值結合使用,會產生具體追蹤器。具體值 (concrete value) 要么是常規值,要么是具體追蹤器。
大多數情況下,從追蹤器值計算出的值本身就是追蹤器值。只有極少數例外,當計算可以完全使用追蹤器攜帶的抽象值完成時,在這種情況下,結果可以是常規值。例如,取得具有 ShapedArray
抽象值的追蹤器的形狀。另一個範例是將具體追蹤器值顯式轉換為常規類型,例如 int(x)
或 x.astype(float)
。另一種情況是 bool(x)
,當具體性使其成為可能時,它會產生一個 Python 布林值。這種情況尤其顯著,因為它經常在控制流程中出現。
以下說明轉換如何引入抽象或具體追蹤器
jax.jit()
:為所有位置引數引入抽象追蹤器 (abstract tracers),但由static_argnums
表示的引數除外,這些引數仍為常規值。jax.pmap()
:為所有位置引數引入抽象追蹤器 (abstract tracers),但由static_broadcasted_argnums
表示的引數除外。jax.vmap()
、jax.make_jaxpr()
、xla_computation()
:為所有位置引數引入抽象追蹤器 (abstract tracers)。jax.jvp()
和jax.grad()
為所有位置引數引入具體追蹤器 (concrete tracers)。例外情況是當這些轉換位於外部轉換內,並且實際引數本身是抽象追蹤器時;在這種情況下,由自動微分轉換引入的追蹤器也是抽象追蹤器。所有高階控制流程原語 (
lax.cond()
、lax.while_loop()
、lax.fori_loop()
、lax.scan()
) 在處理 functionals 時會引入抽象追蹤器 (abstract tracers),無論 JAX 轉換是否正在進行中。
當您的程式碼只能在常規 Python 值上運作時,所有這些都相關,例如基於資料具有條件控制流程的程式碼
def divide(x, y):
return x / y if y >= 1. else 0.
如果我們想要套用 jax.jit()
,我們必須確保指定 static_argnums=1
以確保 y
保持為常規值。這是由於布林運算式 y >= 1.
,它需要具體值 (常規值或追蹤器)。如果我們顯式寫入 bool(y >= 1.)
、int(y)
或 float(y)
,也會發生同樣的情況。
有趣的是,jax.grad(divide)(3., 2.)
可以運作,因為 jax.grad()
使用具體追蹤器,並使用 y
的具體值解析條件式。
緩衝區捐贈#
當 JAX 執行計算時,它會在裝置上為所有輸入和輸出使用緩衝區。如果您知道在計算後不再需要其中一個輸入,並且如果它與其中一個輸出的形狀和元素類型匹配,您可以指定您希望捐贈對應的輸入緩衝區以容納輸出。這將減少執行所需的記憶體,減少量為捐贈緩衝區的大小。
如果您有類似以下模式的東西,您可以使用緩衝區捐贈
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)
您可以將此視為在不可變的 JAX 陣列上執行記憶體效率高的函數式更新的一種方式。在計算的邊界內,XLA 可以為您進行此最佳化,但在 jit/pmap 邊界,您需要向 XLA 保證在呼叫捐贈函式後您將不會使用捐贈的輸入緩衝區。
您可以透過使用 donate_argnums 參數來實現這一點,該參數適用於函式 jax.jit()
、jax.pjit()
和 jax.pmap()
。此參數是一個索引序列 (從 0 開始),指向位置引數列表
def add(x, y):
return x + y
x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)
請注意,當使用關鍵字引數呼叫您的函式時,目前這不起作用!以下程式碼不會捐贈任何緩衝區
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)
如果其緩衝區被捐贈的引數是一個 pytree,則其組件的所有緩衝區都會被捐贈
def add_ones(xs: List[Array]):
return [x + 1 for x in xs]
xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)
不允許捐贈隨後在計算中使用的緩衝區,JAX 會給出錯誤,因為在捐贈後 y 的緩衝區已失效
# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1 # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer
如果捐贈的緩衝區未使用,您會收到警告,例如,因為捐贈的緩衝區多於可用於輸出的緩衝區
# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}
如果沒有輸出的形狀與捐贈的形狀匹配,則捐贈也可能未使用
y = jax.device_put(np.ones((1, 3))) # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}
梯度在使用 where
的地方包含 NaN#
如果您使用 where
定義函式以避免未定義的值,如果您不小心,您可能會在反向微分中獲得 NaN
def my_log(x):
return jnp.where(x > 0., jnp.log(x), 0.)
my_log(0.) ==> 0. # Ok
jax.grad(my_log)(0.) ==> NaN
一個簡短的解釋是,在 grad
計算期間,對應於未定義的 jnp.log(x)
的伴隨值 (adjoint) 是 NaN
,並且它會累積到 jnp.where
的伴隨值。編寫此類函式的正確方法是確保在部分定義的函式內部有一個 jnp.where
,以確保伴隨值始終是有限的
def safe_for_grad_log(x):
return jnp.log(jnp.where(x > 0., x, 1.))
safe_for_grad_log(0.) ==> 0. # Ok
jax.grad(safe_for_grad_log)(0.) ==> 0. # Ok
除了原始的 jnp.where
之外,可能還需要內部的 jnp.where
,例如
def my_log_or_y(x, y):
"""Return log(x) if x > 0 or y"""
return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)
延伸閱讀
為什麼基於排序順序的函式的梯度為零?#
如果您定義一個使用依賴於輸入的相對順序的操作 (例如 max
、greater
、argsort
等) 處理輸入的函式,那麼您可能會驚訝地發現梯度處處為零。這是一個範例,我們定義 f(x) 為一個步階函數,當 x 為負數時回傳 0,當 x 為正數時回傳 1
import jax
import numpy as np
import jax.numpy as jnp
def f(x):
return (x > 0).astype(float)
df = jax.vmap(jax.grad(f))
x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])
print(f"f(x) = {f(x)}")
# f(x) = [0. 0. 0. 1. 1.]
print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]
梯度處處為零的事實起初可能會令人困惑:畢竟,輸出確實會隨著輸入而改變,那麼梯度怎麼可能為零?但是,在這種情況下,零被證明是正確的結果。
為什麼會這樣?請記住,微分正在測量在給定 x
的無窮小變化時 f
的變化。對於 x=1.0
,f
回傳 1.0
。如果我們擾動 x
使其稍微大一點或小一點,這不會改變輸出,因此根據定義,grad(f)(1.0)
應該為零。相同的邏輯適用於所有大於零的 f
值:無窮小地擾動輸入不會改變輸出,因此梯度為零。同樣地,對於所有小於零的 x
值,輸出為零。擾動 x
不會改變此輸出,因此梯度為零。這讓我們剩下 x=0
這個棘手的情況。當然,如果您向上擾動 x
,它會改變輸出,但這是有問題的:x
的無窮小變化會產生函數值的有限變化,這意味著梯度是未定義的。幸運的是,我們有另一種方法可以在這種情況下測量梯度:我們向下擾動函數,在這種情況下,輸出不會改變,因此梯度為零。JAX 和其他自動微分系統傾向於以這種方式處理不連續性:如果正梯度和負梯度不一致,但其中一個已定義而另一個未定義,我們使用已定義的那一個。根據梯度的此定義,從數學和數值上來說,此函式的梯度處處為零。
問題源於我們的函數在 x = 0
處具有不連續性。我們的 f
在這裡本質上是一個 Heaviside 步階函數,我們可以將 Sigmoid 函數 用作平滑的替代品。當 x 遠離零時,sigmoid 大致等於 heaviside 函數,但它用平滑、可微分的曲線取代了 x = 0
處的不連續性。由於使用了 jax.nn.sigmoid()
,我們得到了類似的計算,具有明確定義的梯度
def g(x):
return jax.nn.sigmoid(x)
dg = jax.vmap(jax.grad(g))
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
with np.printoptions(suppress=True, precision=2):
print(f"g(x) = {g(x)}")
# g(x) = [0. 0.27 0.5 0.73 1. ]
print(f"dg(x) = {dg(x)}")
# dg(x) = [0. 0.2 0.25 0.2 0. ]
jax.nn
子模組也具有其他常見的基於排序的函數的平滑版本,例如 jax.nn.softmax()
可以取代 jax.numpy.argmax()
的使用,jax.nn.soft_sign()
可以取代 jax.numpy.sign()
的使用,jax.nn.softplus()
或 jax.nn.squareplus()
可以取代 jax.nn.relu()
的使用,等等。
如何將 JAX Tracer 轉換為 NumPy 陣列?#
當在執行階段檢查轉換後的 JAX 函數時,您會發現陣列值被 Tracer
物件取代
@jax.jit
def f(x):
print(type(x))
return x
f(jnp.arange(5))
這會印出以下內容
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
一個常見的問題是如何將這樣的追蹤器轉換回普通的 NumPy 陣列。簡而言之,不可能將 Tracer 轉換為 NumPy 陣列,因為追蹤器是具有給定形狀和 dtype 的每個可能值的抽象表示,而 numpy 陣列是該抽象類別的具體成員。有關追蹤器如何在 JAX 轉換的上下文中運作的更多討論,請參閱 JIT 機制。
將 Tracer 轉換回陣列的問題通常出現在另一個目標的上下文中,該目標與在執行階段存取計算中的中間值有關。例如
如果您希望在執行階段印出追蹤值以進行偵錯,您可以考慮使用
jax.debug.print()
。如果您希望在轉換後的 JAX 函數中呼叫非 JAX 程式碼,您可以考慮使用
jax.pure_callback()
,範例可在 Pure callback example 中找到。如果您希望在執行階段輸入或輸出陣列緩衝區 (例如,從檔案載入資料,或將陣列內容記錄到磁碟),您可以考慮使用
jax.experimental.io_callback()
,範例可在 IO callback example 中找到。
有關執行階段回呼及其使用範例的更多資訊,請參閱 JAX 中的外部回呼。
為什麼某些 CUDA 函式庫載入/初始化失敗?#
在解析動態函式庫時,JAX 使用常用的 動態連結器搜尋模式。JAX 設定 RPATH
指向 pip 安裝的 NVIDIA CUDA 套件的 JAX 相對位置,如果已安裝,則優先使用它們。如果 ld.so
無法在其常用的搜尋路徑中找到您的 CUDA 執行階段函式庫,那麼您必須在 LD_LIBRARY_PATH
中顯式包含這些函式庫的路徑。確保您的 CUDA 檔案可被發現的最簡單方法是簡單地安裝 nvidia-*-cu12
pip 套件,這些套件包含在標準 jax[cuda_12]
安裝選項中。
有時,即使您已確保您的執行階段函式庫可被發現,在載入或初始化它們時仍可能存在一些問題。此類問題的常見原因是執行階段 CUDA 函式庫初始化時記憶體不足。有時會發生這種情況,因為 JAX 會為更快的執行速度預先分配過大的目前可用裝置記憶體區塊,偶爾會導致剩餘可用於執行階段 CUDA 函式庫初始化的記憶體不足。
當執行多個 JAX 實例、與 TensorFlow 並行執行 JAX (TensorFlow 執行自己的預先分配),或在 GPU 被其他進程大量使用的系統上執行 JAX 時,這種情況尤其可能發生。如有疑問,請嘗試再次執行程式,並減少預先分配,可以透過將 XLA_PYTHON_CLIENT_MEM_FRACTION
從預設值 .75
降低,或設定 XLA_PYTHON_CLIENT_PREALLOCATE=false
來實現。有關更多詳細資訊,請參閱關於 JAX GPU 記憶體分配 的頁面。