分散式陣列與自動平行化#
本教學討論透過 jax.Array
實現的平行化,這是 JAX v0.4.1 及更新版本中可用的統一陣列物件模型。
from typing import Optional
import numpy as np
import jax
import jax.numpy as jnp
⚠️ 警告:此筆記本需要 8 個裝置才能執行。
if len(jax.local_devices()) < 8:
raise Exception("Notebook requires 8 devices to run")
簡介與快速範例#
透過閱讀本教學筆記本,您將了解 jax.Array
,這是一種用於表示陣列的統一資料類型,即使其實體儲存跨越多個裝置。您也將了解如何將 jax.Array
與 jax.jit
一起使用,以提供基於編譯器的自動平行化。
在我們逐步思考之前,這裡有一個快速範例。首先,我們將建立一個跨多個裝置分片的 jax.Array
from jax.sharding import PartitionSpec as P, NamedSharding
# Create a Sharding object to distribute a value across devices:
mesh = jax.make_mesh((4, 2), ('x', 'y'))
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
接下來,我們將對其應用計算,並視覺化結果值也如何跨多個裝置儲存
z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
jnp.sin
應用的評估已自動在輸入值 (和輸出值) 儲存的裝置上平行化
# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
The slowest run took 8.96 times longer than the fastest. This could mean that an intermediate result is being cached.
25.2 ms ± 30.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
2.4 ms ± 61.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
現在讓我們更詳細地查看每個部分!
計算遵循資料分片並自動平行化#
透過分片的輸入資料,編譯器可以為我們提供平行計算。特別是,使用 jax.jit
裝飾的函數可以在分片陣列上操作,而無需將資料複製到單一裝置上。相反,計算遵循分片:根據輸入資料的分片,編譯器決定中間值和輸出值的分片,並平行化其評估,甚至在必要時插入通訊操作。
例如,最簡單的計算是元素級計算
mesh = jax.make_mesh((4, 2), ('a', 'b'))
x = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
print('input sharding:')
jax.debug.visualize_array_sharding(x)
y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y)
input sharding:
output sharding:
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
在這裡,對於元素級操作 jnp.sin
,編譯器選擇的輸出分片與輸入相同。此外,編譯器自動平行化計算,以便每個裝置平行地從其輸入分片計算其輸出分片。
換句話說,即使我們編寫 jnp.sin
計算時,彷彿單一機器要執行它,編譯器也會為我們分割計算並在多個裝置上執行它。
我們也可以對不僅僅是元素級操作執行相同的操作。考慮具有分片輸入的矩陣乘法
y = jax.device_put(x, NamedSharding(mesh, P('a', None)))
z = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)
w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)
lhs sharding:
rhs sharding:
out sharding:
┌───────────────────────┐ │ TPU 0,1 │ ├───────────────────────┤ │ TPU 2,3 │ ├───────────────────────┤ │ TPU 6,7 │ ├───────────────────────┤ │ TPU 4,5 │ └───────────────────────┘
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
在這裡,編譯器選擇輸出分片,以便它可以最大程度地平行化計算:無需通訊,每個裝置都已擁有計算其輸出分片所需的輸入分片。
我們如何確定它實際上是在平行運行?我們可以進行一個簡單的計時實驗
x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
┌───────────────────────┐
│ │
│ │
│ │
│ │
│ TPU 0 │
│ │
│ │
│ │
│ │
└───────────────────────┘
np.allclose(jnp.dot(x_single, x_single),
jnp.dot(y, z))
True
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
49.7 ms ± 349 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
7.47 ms ± 44.8 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
即使複製分片的 Array
也會產生具有輸入分片結果
w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
因此,計算遵循資料放置:當我們使用 jax.device_put
明確地分片資料,並將函數應用於該資料時,編譯器會嘗試平行化計算並決定輸出分片。這種分片資料的策略是 JAX 遵循明確裝置放置策略的推廣。
當明確分片不一致時,JAX 會產生錯誤#
但是,如果計算的兩個引數被明確放置在不同的裝置集上,或具有不相容的裝置順序怎麼辦?在這些不明確的情況下,會引發錯誤
import textwrap
from termcolor import colored
def print_exception(e):
name = colored(f'{type(e).__name__}', 'red', force_color=True)
print(textwrap.fill(f'{name}: {str(e)}'))
sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))
sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))
y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x1 of jax.numpy.add with shape int32[24] and
device ids [0, 1, 2, 3] on platform TPU and argument x2 of
jax.numpy.add with shape int32[24] and device ids [4, 5, 6, 7] on
platform TPU
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]
sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))
sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))
y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x1 of jax.numpy.add with shape int32[24] and
device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and argument x2 of
jax.numpy.add with shape int32[24] and device ids [0, 1, 2, 3, 6, 7,
4, 5] on platform TPU
我們說,已使用 jax.device_put
明確放置或分片的陣列會提交到其裝置,因此不會自動移動。有關更多資訊,請參閱裝置放置 FAQ。
當陣列未使用 jax.device_put
明確放置或分片時,它們會未提交地放置在預設裝置上。與已提交的陣列不同,未提交的陣列可以自動移動和重新分片:也就是說,即使其他引數明確放置在不同的裝置上,未提交的陣列也可以作為計算的引數。
例如,jnp.zeros
、jnp.arange
和 jnp.array
的輸出是未提交的
y = jax.device_put(x, sharding1)
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!')
no error!
限制 jit
編譯程式碼中中間值的分片#
雖然編譯器會嘗試決定函數的中間值和輸出應如何分片,但我們也可以使用 jax.lax.with_sharding_constraint
給它提示。使用 jax.lax.with_sharding_constraint
非常像 jax.device_put
,除了我們在階段性輸出 (即 jit
裝飾) 函數內部使用它
mesh = jax.make_mesh((4, 2), ('x', 'y'))
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
@jax.jit
def f(x):
x = x + 1
y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))
return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
┌───────┬───────┬───────┬───────┐ │ │ │ │ │ │ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │ │ │ │ │ │ │ │ │ │ │ ├───────┼───────┼───────┼───────┤ │ │ │ │ │ │ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │ │ │ │ │ │ │ │ │ │ │ └───────┴───────┴───────┴───────┘
@jax.jit
def f(x):
x = x + 1
y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
┌───────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3,4,5,6,7 │ │ │ │ │ │ │ │ │ └───────────────────────┘
透過新增 with_sharding_constraint
,我們限制了輸出的分片。除了尊重特定中間值的註解之外,編譯器還將使用註解來決定其他值的分片。
通常,註解計算的輸出是一個好習慣,例如基於最終如何使用這些值。
範例:神經網路#
⚠️ 警告:以下內容旨在簡單示範使用 jax.Array
的自動分片傳播,但可能無法反映真實範例的最佳實務。 例如,真實範例可能需要更多使用 with_sharding_constraint
。
我們可以使用 jax.device_put
和 jax.jit
的 computation-follows-sharding 功能來平行化神經網路中的計算。以下是一些簡單的範例,基於這個基本神經網路
import jax
import jax.numpy as jnp
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jnp.maximum(outputs, 0)
return outputs
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))
def init_layer(key, n_in, n_out):
k1, k2 = jax.random.split(key)
W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
b = jax.random.normal(k2, (n_out,))
return W, b
def init_model(key, layer_sizes, batch_size):
key, *keys = jax.random.split(key, len(layer_sizes))
params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
key, *keys = jax.random.split(key, 3)
inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
return params, (inputs, targets)
layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
8 路批次資料平行處理#
mesh = jax.make_mesh((8,), ('batch',))
sharding = NamedSharding(mesh, P('batch'))
replicated_sharding = NamedSharding(mesh, P())
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, replicated_sharding)
loss_jit(params, batch)
Array(23.469475, dtype=float32)
step_size = 1e-5
for _ in range(30):
grads = gradfun(params, batch)
params = [(W - step_size * dW, b - step_size * db)
for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.760109
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
53.8 ms ± 1.14 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
351 ms ± 81.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
4 路批次資料平行處理與 2 路模型張量平行處理#
mesh = jax.make_mesh((4, 2), ('batch', 'model'))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
┌───────┐ │TPU 0,1│ ├───────┤ │TPU 2,3│ ├───────┤ │TPU 6,7│ ├───────┤ │TPU 4,5│ └───────┘
┌───────┐ │TPU 0,1│ ├───────┤ │TPU 2,3│ ├───────┤ │TPU 6,7│ ├───────┤ │TPU 4,5│ └───────┘
replicated_sharding = NamedSharding(mesh, P())
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
W1 = jax.device_put(W1, replicated_sharding)
b1 = jax.device_put(b1, replicated_sharding)
W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))
b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))
W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))
b3 = jax.device_put(b3, replicated_sharding)
W4 = jax.device_put(W4, replicated_sharding)
b4 = jax.device_put(b4, replicated_sharding)
params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
jax.debug.visualize_array_sharding(W2)
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
jax.debug.visualize_array_sharding(W3)
┌───────────────────────┐ │ │ │ TPU 0,2,4,6 │ │ │ │ │ ├───────────────────────┤ │ │ │ TPU 1,3,5,7 │ │ │ │ │ └───────────────────────┘
print(loss_jit(params, batch))
10.760109
step_size = 1e-5
for _ in range(30):
grads = gradfun(params, batch)
params = [(W - step_size * dW, b - step_size * db)
for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.752513
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
┌───────────────────────┐ │ │ │ TPU 0,2,4,6 │ │ │ │ │ ├───────────────────────┤ │ │ │ TPU 1,3,5,7 │ │ │ │ │ └───────────────────────┘
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
51.4 ms ± 454 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)