Asynchronous dispatch
JAX uses asynchronous dispatch to hide Python overheads. Consider the following program:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax import random
>>> x = random.uniform(random.key(0), (1000, 1000))
>>> # Printing the result (i.e. evaluating `repr(result)` or `str(result)`)
>>> # will block until the value is ready.
>>> jnp.dot(x, x) + 3.
Array([[258.01971436, 249.64862061, 257.13372803, ...,
236.67948914, 250.68939209, 241.36853027],
[265.65979004, 256.28912354, 262.18252563, ...,
242.03181458, 256.16757202, 252.44122314],
[262.38916016, 255.72747803, 261.23059082, ...,
240.83563232, 255.41094971, 249.62471008],
...,
[259.15814209, 253.09197998, 257.72174072, ...,
242.23876953, 250.72680664, 247.16642761],
[271.22662354, 261.91204834, 265.33398438, ...,
248.26651001, 262.05389404, 261.33700562],
[257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
248.62597656, 243.22348022]], dtype=float32)
When an operation such as jnp.dot(x, x) is executed, JAX does not wait for the operation to complete before returning control to the Python program. Instead, JAX returns a jax.Array value, which is a future: a value that will be produced on an accelerator device at some point but is not necessarily available immediately. We can inspect the shape or dtype of the jax.Array without waiting for the computation that produced it to complete, and we can even pass it on to another JAX computation, as we do with the addition here. Only when we actually inspect the array's value from the host, for example by printing it or by converting it into a plain numpy.ndarray, does JAX force the Python code to wait for the computation to complete.
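For example, continuing the session above (a minimal sketch; the variable names y and z are only illustrative):

>>> y = jnp.dot(x, x)          # returns immediately; the result may still be computing
>>> y.shape                    # shape and dtype are available without blocking
(1000, 1000)
>>> z = y + 3.                 # the future can be fed into another JAX computation
>>> _ = z.block_until_ready()  # only here do we wait for the device to finish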
Asynchronous dispatch is useful because it allows Python code to "run ahead" of an accelerator device, keeping Python code off the critical path. Provided the Python code enqueues work on the device faster than the device can execute it, and provided the Python code does not actually need to inspect a computation's output on the host, a Python program can enqueue arbitrary amounts of work and avoid having the accelerator wait.
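For instance, a loop can enqueue many steps before waiting on anything (a sketch reusing x from above; the step count of 10 is arbitrary):

>>> y = x
>>> for _ in range(10):
...     y = jnp.dot(y, x)      # each call only enqueues work on the device
>>> _ = y.block_until_ready()  # block once, after all the work has been queued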
Asynchronous dispatch has a slightly surprising consequence for microbenchmarks:
>>> %time jnp.dot(x, x)
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
Array([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)
269 µs is a surprisingly small time for a 1000x1000 matrix multiplication on CPU! However, asynchronous dispatch is misleading us: we are not timing the execution of the matrix multiplication, only the time it takes to dispatch the work. To measure the true cost of the operation, we must either read the value on the host (for example, by converting it into a plain host-side numpy array) or use the block_until_ready() method of a jax.Array value to wait for the computation that produced it to complete:
>>> %time np.asarray(jnp.dot(x, x))
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
238.36853],
[262.6598 , 253.28911, 259.18253, ..., 239.03181, 253.16757,
249.44122],
[259.38916, 252.72748, 258.2306 , ..., 237.83563, 252.41095,
246.62471],
...,
[256.15814, 250.09198, 254.72173, ..., 239.23877, 247.7268 ,
244.16643],
[268.22662, 258.91205, 262.33398, ..., 245.26651, 259.0539 ,
258.337 ],
[254.16135, 251.75433, 256.083 , ..., 238.59848, 245.62598,
240.22348]], dtype=float32)
>>> %time jnp.dot(x, x).block_until_ready()
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
Array([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)
Blocking without transferring the result back to Python is usually faster, and it is usually the best choice when writing microbenchmarks of computation times.
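Outside of IPython's %time magic, the same pattern can be expressed with the standard timeit module (a sketch; the repeat count of 10 is arbitrary):

>>> import timeit
>>> seconds_per_call = timeit.timeit(
...     lambda: jnp.dot(x, x).block_until_ready(),  # block so we time the computation, not just the dispatch
...     number=10) / 10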