分散式陣列與自動平行化#

Open in Colab Open in Kaggle

本教學討論透過 jax.Array 實現的平行化,這是 JAX v0.4.1 及更新版本中可用的統一陣列物件模型。

from typing import Optional

import numpy as np

import jax
import jax.numpy as jnp

⚠️ 警告:此筆記本需要 8 個裝置才能執行。

if len(jax.local_devices()) < 8:
  raise Exception("Notebook requires 8 devices to run")

簡介與快速範例#

透過閱讀本教學筆記本,您將了解 jax.Array,這是一種用於表示陣列的統一資料類型,即使其實體儲存跨越多個裝置。您也將了解如何將 jax.Arrayjax.jit 一起使用,以提供基於編譯器的自動平行化。

在我們逐步思考之前,這裡有一個快速範例。首先,我們將建立一個跨多個裝置分片的 jax.Array

from jax.sharding import PartitionSpec as P, NamedSharding
# Create a Sharding object to distribute a value across devices:
mesh = jax.make_mesh((4, 2), ('x', 'y'))
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

接下來,我們將對其應用計算,並視覺化結果值也如何跨多個裝置儲存

z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

jnp.sin 應用的評估已自動在輸入值 (和輸出值) 儲存的裝置上平行化

# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
The slowest run took 8.96 times longer than the fastest. This could mean that an intermediate result is being cached.
25.2 ms ± 30.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
2.4 ms ± 61.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)

現在讓我們更詳細地查看每個部分!

Sharding 描述陣列值在跨裝置記憶體中的佈局方式#

Sharding 基礎知識,以及 NamedSharding 子類別#

為了跨多個裝置平行化計算,我們首先必須跨多個裝置佈局輸入資料。

在 JAX 中,Sharding 物件描述分散式記憶體佈局。它們可以與 jax.device_put 一起使用,以產生具有分散式佈局的值。

例如,這是一個具有單裝置 Sharding 的值

import jax
x = jax.random.normal(jax.random.key(0), (8192, 8192))
jax.debug.visualize_array_sharding(x)
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘

在這裡,我們使用 jax.debug.visualize_array_sharding 函數來顯示值 x 儲存在記憶體中的位置。x 的所有內容都儲存在單一裝置上,因此視覺化非常無聊!

但是我們可以透過使用 jax.device_putSharding 物件來跨多個裝置分片 x。首先,我們使用 jax.make_mesh 建立 Devicesnumpy.ndarray,這會將硬體拓撲納入考量以決定 Device 的順序

from jax.sharding import Mesh, PartitionSpec, NamedSharding

P = PartitionSpec

mesh = jax.make_mesh((4, 2), ('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

我們可以定義一個輔助函數來簡化操作

default_mesh = jax.make_mesh((4, 2), ('a', 'b'))

def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
  return NamedSharding(mesh, pspec)
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

在這裡,我們使用 P('a', 'b') 來表示 x 的第一軸和第二軸應分別在裝置網格軸 'a''b' 上分片。我們可以輕鬆切換到 P('b', 'a'),以在不同裝置上分片 x 的軸

y = jax.device_put(x, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(y)
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │
│       │       │       │       │
│       │       │       │       │
├───────┼───────┼───────┼───────┤
│       │       │       │       │
│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘
# This `None` means that `x` is not sharded on its second dimension,
# and since the Mesh axis name 'b' is not mentioned, shards are
# replicated across it.
y = jax.device_put(x, mesh_sharding(P('a', None)))
jax.debug.visualize_array_sharding(y)
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘

在這裡,因為 P('a', None) 沒有提及 Mesh 軸名稱 'b',所以我們在軸 'b' 上獲得複製。None 在這裡僅作為佔位符,以對齊值 x 的第二軸,而沒有表達在任何網格軸上的分片。(作為簡寫,可以省略尾隨的 None,因此 P('a', None)P('a') 意思相同。但明確一點也無妨!)

為了僅在 x 的第二軸上分片,我們可以在 PartitionSpec 中使用 None 佔位符

y = jax.device_put(x, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(y)
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
y = jax.device_put(x, mesh_sharding(P(None, 'a')))
jax.debug.visualize_array_sharding(y)
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘

對於固定的網格,我們甚至可以將 x 的一個邏輯軸劃分到多個裝置網格軸上

y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 3         │
├───────────────────────┤
│         TPU 6         │
├───────────────────────┤
│         TPU 7         │
├───────────────────────┤
│         TPU 4         │
├───────────────────────┤
│         TPU 5         │
└───────────────────────┘

使用 NamedSharding 可以輕鬆定義裝置網格一次並給其軸命名,然後在每個 device_putPartitionSpec 中根據需要引用這些名稱。

計算遵循資料分片並自動平行化#

透過分片的輸入資料,編譯器可以為我們提供平行計算。特別是,使用 jax.jit 裝飾的函數可以在分片陣列上操作,而無需將資料複製到單一裝置上。相反,計算遵循分片:根據輸入資料的分片,編譯器決定中間值和輸出值的分片,並平行化其評估,甚至在必要時插入通訊操作。

例如,最簡單的計算是元素級計算

mesh = jax.make_mesh((4, 2), ('a', 'b'))
x = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
print('input sharding:')
jax.debug.visualize_array_sharding(x)

y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y)
input sharding:
output sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

在這裡,對於元素級操作 jnp.sin,編譯器選擇的輸出分片與輸入相同。此外,編譯器自動平行化計算,以便每個裝置平行地從其輸入分片計算其輸出分片。

換句話說,即使我們編寫 jnp.sin 計算時,彷彿單一機器要執行它,編譯器也會為我們分割計算並在多個裝置上執行它。

我們也可以對不僅僅是元素級操作執行相同的操作。考慮具有分片輸入的矩陣乘法

y = jax.device_put(x, NamedSharding(mesh, P('a', None)))
z = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)

w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)
lhs sharding:
rhs sharding:
out sharding:
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

在這裡,編譯器選擇輸出分片,以便它可以最大程度地平行化計算:無需通訊,每個裝置都已擁有計算其輸出分片所需的輸入分片。

我們如何確定它實際上是在平行運行?我們可以進行一個簡單的計時實驗

x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘
np.allclose(jnp.dot(x_single, x_single),
            jnp.dot(y, z))
True
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
49.7 ms ± 349 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
7.47 ms ± 44.8 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)

即使複製分片的 Array 也會產生具有輸入分片結果

w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘

因此,計算遵循資料放置:當我們使用 jax.device_put 明確地分片資料,並將函數應用於該資料時,編譯器會嘗試平行化計算並決定輸出分片。這種分片資料的策略是 JAX 遵循明確裝置放置策略的推廣。

當明確分片不一致時,JAX 會產生錯誤#

但是,如果計算的兩個引數被明確放置在不同的裝置集上,或具有不相容的裝置順序怎麼辦?在這些不明確的情況下,會引發錯誤

import textwrap
from termcolor import colored

def print_exception(e):
  name = colored(f'{type(e).__name__}', 'red', force_color=True)
  print(textwrap.fill(f'{name}: {str(e)}'))
sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))
sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))

y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x1 of jax.numpy.add with shape int32[24] and
device ids [0, 1, 2, 3] on platform TPU and argument x2 of
jax.numpy.add with shape int32[24] and device ids [4, 5, 6, 7] on
platform TPU
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]

sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))
sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))

y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x1 of jax.numpy.add with shape int32[24] and
device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and argument x2 of
jax.numpy.add with shape int32[24] and device ids [0, 1, 2, 3, 6, 7,
4, 5] on platform TPU

我們說,已使用 jax.device_put 明確放置或分片的陣列會提交到其裝置,因此不會自動移動。有關更多資訊,請參閱裝置放置 FAQ

當陣列使用 jax.device_put 明確放置或分片時,它們會未提交地放置在預設裝置上。與已提交的陣列不同,未提交的陣列可以自動移動和重新分片:也就是說,即使其他引數明確放置在不同的裝置上,未提交的陣列也可以作為計算的引數。

例如,jnp.zerosjnp.arangejnp.array 的輸出是未提交的

y = jax.device_put(x, sharding1)
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!')
no error!

限制 jit 編譯程式碼中中間值的分片#

雖然編譯器會嘗試決定函數的中間值和輸出應如何分片,但我們也可以使用 jax.lax.with_sharding_constraint 給它提示。使用 jax.lax.with_sharding_constraint 非常像 jax.device_put,除了我們在階段性輸出 (即 jit 裝飾) 函數內部使用它

mesh = jax.make_mesh((4, 2), ('x', 'y'))
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))
  return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │
│       │       │       │       │
│       │       │       │       │
├───────┼───────┼───────┼───────┤
│       │       │       │       │
│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
  return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│  TPU 0,1,2,3,4,5,6,7  │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘

透過新增 with_sharding_constraint,我們限制了輸出的分片。除了尊重特定中間值的註解之外,編譯器還將使用註解來決定其他值的分片。

通常,註解計算的輸出是一個好習慣,例如基於最終如何使用這些值。

範例:神經網路#

⚠️ 警告:以下內容旨在簡單示範使用 jax.Array 的自動分片傳播,但可能無法反映真實範例的最佳實務。 例如,真實範例可能需要更多使用 with_sharding_constraint

我們可以使用 jax.device_putjax.jit 的 computation-follows-sharding 功能來平行化神經網路中的計算。以下是一些簡單的範例,基於這個基本神經網路

import jax
import jax.numpy as jnp
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.maximum(outputs, 0)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))
def init_layer(key, n_in, n_out):
  k1, k2 = jax.random.split(key)
  W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
  b = jax.random.normal(k2, (n_out,))
  return W, b

def init_model(key, layer_sizes, batch_size):
  key, *keys = jax.random.split(key, len(layer_sizes))
  params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

  key, *keys = jax.random.split(key, 3)
  inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
  targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

  return params, (inputs, targets)

layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192

params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)

8 路批次資料平行處理#

mesh = jax.make_mesh((8,), ('batch',))
sharding = NamedSharding(mesh, P('batch'))
replicated_sharding = NamedSharding(mesh, P())
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, replicated_sharding)
loss_jit(params, batch)
Array(23.469475, dtype=float32)
step_size = 1e-5

for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]

print(loss_jit(params, batch))
10.760109
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
53.8 ms ± 1.14 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
351 ms ± 81.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)

4 路批次資料平行處理與 2 路模型張量平行處理#

mesh = jax.make_mesh((4, 2), ('batch', 'model'))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 6,7│
├───────┤
│TPU 4,5│
└───────┘
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 6,7│
├───────┤
│TPU 4,5│
└───────┘
replicated_sharding = NamedSharding(mesh, P())
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params

W1 = jax.device_put(W1, replicated_sharding)
b1 = jax.device_put(b1, replicated_sharding)

W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))
b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))

W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))
b3 = jax.device_put(b3, replicated_sharding)

W4 = jax.device_put(W4, replicated_sharding)
b4 = jax.device_put(b4, replicated_sharding)

params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
jax.debug.visualize_array_sharding(W2)
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
jax.debug.visualize_array_sharding(W3)
┌───────────────────────┐
│                       │
│      TPU 0,2,4,6      │
│                       │
│                       │
├───────────────────────┤
│                       │
│      TPU 1,3,5,7      │
│                       │
│                       │
└───────────────────────┘
print(loss_jit(params, batch))
10.760109
step_size = 1e-5

for _ in range(30):
    grads = gradfun(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.752513
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
┌───────────────────────┐
│                       │
│      TPU 0,2,4,6      │
│                       │
│                       │
├───────────────────────┤
│                       │
│      TPU 1,3,5,7      │
│                       │
│                       │
└───────────────────────┘
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
51.4 ms ± 454 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)

尖銳的重點#

產生隨機數#

JAX 帶有一個功能性、確定性的 隨機數產生器。它是 jax.random 模組中各種取樣函數 (例如 jax.random.uniform) 的基礎。

JAX 的隨機數是由基於計數器的 PRNG 產生,因此原則上,隨機數產生應該是對計數器值的純映射。原則上,純映射是一種微不足道的可分割操作。它應該不需要跨裝置通訊,也不需要在裝置之間進行任何冗餘計算。

然而,由於歷史原因,現有的穩定 RNG 實作並非自動可分割的。

考慮以下範例,其中一個函數繪製隨機均勻數並將其元素級地新增到輸入

@jax.jit
def f(key, x):
  numbers = jax.random.uniform(key, x.shape)
  return x + numbers

key = jax.random.key(42)
mesh = Mesh(jax.devices(), 'x')
x_sharding = NamedSharding(mesh, P('x'))
x = jax.device_put(jnp.arange(24), x_sharding)

在分片的輸入上,函數 f 產生也是分片的輸出

jax.debug.visualize_array_sharding(f(key, x))
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘

但是,如果我們檢查在此分片輸入上針對 f 的編譯計算,我們會看到它確實涉及一些通訊

f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? True

一種解決方法是使用實驗性升級標誌 jax_threefry_partitionable 配置 JAX。啟用此標誌後,“collective permute” 操作現在已從編譯計算中消失

jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? False

輸出仍然是分片的

jax.debug.visualize_array_sharding(f(key, x))
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘

然而,jax_threefry_partitionable 選項的一個注意事項是,即使隨機值是由相同的隨機金鑰產生,產生的隨機值也可能與未設定標誌時不同

jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))
print()

jax.config.update('jax_threefry_partitionable', True)
print('Partitionable:')
print(f(key, x))
Stable:
[ 0.72503686  1.8532515   2.983416    3.083253    4.0332246   5.4782867
  6.1720605   7.6900277   8.602836    9.810046   10.861367   11.907651
 12.330483   13.456195   14.808557   15.960099   16.067581   17.739723
 18.335474   19.46401    20.390276   21.116539   22.858128   23.223194  ]

Partitionable:
[ 0.48870957  1.6797972   2.6162715   3.561016    4.4506445   5.585866
  6.0748096   7.775133    8.698959    9.818634   10.350306   11.87282
 12.925881   13.86013    14.477554   15.818481   16.711355   17.586697
 18.073738   19.777622   20.404566   21.119123   22.026257   23.63918   ]

jax_threefry_partitionable 模式下,JAX PRNG 仍然是確定性的,但其實作是新的 (且正在開發中)。針對給定金鑰產生的隨機值在給定的 JAX 版本 (或 main 分支上的給定提交) 中將是相同的,但可能因發行版本而異。