jax.experimental.custom_partitioning module#

API#

jax.experimental.custom_partitioning.custom_partitioning(fun, static_argnums=())[source]#

Inserts a CustomCallOp into the XLA graph with custom SPMD lowering rules.

@custom_partitioning
def f(*args):
  return ...

def propagate_user_sharding(mesh, user_shape):
  '''Update the sharding of the op from a user's shape.sharding.'''
  user_sharding = jax.tree.map(lambda x: x.sharding, user_shape)

def partition(mesh, arg_shapes, result_shape):
  def lower_fn(*args):
    ... builds computation on per-device shapes ...
  result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
  arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
  # result_shardings and arg_shardings may optionally be modified and the
  # partitioner will insert collectives to reshape.
  return mesh, lower_fn, result_shardings, arg_shardings

def infer_sharding_from_operands(mesh, arg_shapes, shape):
  '''Compute the result sharding from the sharding of the operands.'''
  arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)


f.def_partition(partition, propagate_user_sharding,
                infer_sharding_from_operands=infer_sharding_from_operands,
                sharding_rule='i j -> i j')
When config.use_shardy_partitioner.value is True, the sharding_rule is
used; otherwise, propagate_user_sharding and infer_sharding_from_operands
are used.
Instead of using an Einsum-like notation string, sharding_rule can also be
a SdyShardingRule object, such as sharding_rule=SdyShardingRule(("i", "j"), ("i", "j")).

The args to def_partition are as follows:

  • propagate_user_sharding: Callable which takes the sharding of a user (in the dag) and returns a suggestion for a new NamedSharding. The default implementation is just to return the suggested sharding.

  • partition: Callable which takes the SPMD suggested partition shapes and partition specs and returns the mesh, a per-shard lowering function, and the final input and output sharding specs (the SPMD partitioner will repartition the inputs to match). The mesh is returned to allow configuring axis_names for collectives when no mesh is provided.

  • infer_sharding_from_operands: Callable which computes an output NamedSharding from the NamedSharding chosen for each argument.

  • decode_shardings: When set to True, convert input GSPMDShardings to NamedSharding if possible. This may not be possible if the user does not provide a contextual mesh.

  • sharding_rule: An SdyShardingRule object or an Einsum-like notation string describing the sharding rule. We borrow the idea from einops.rearrange strings, using a space separator between factors and allowing multiple letters in factor names.
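To make the notation concrete, here is a small self-contained sketch of well-formed rule strings (the rules themselves are hypothetical examples, not taken from a specific op): `->` separates operands from results, `,` separates multiple operands as in einsum, and factor names may use multiple letters.

```python
# Hypothetical Einsum-like sharding-rule strings. Factors are separated
# by spaces (an idea borrowed from einops.rearrange strings), and a
# factor name may use multiple letters.
elementwise = 'i j -> i j'                        # result sharded like the operand
batched = '... i -> ... i'                        # leading batch dims propagate
multi_letter = 'batch seq, seq out -> batch out'  # two operands, multi-letter factors

# Parse each rule into per-operand and result factor lists.
for rule in (elementwise, batched, multi_letter):
    operands, result = rule.split('->')
    print([part.split() for part in operands.split(',')], '->', result.split())
```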

Positional arguments can be specified as static using static_argnums. JAX uses inspect.signature(fun) to resolve these positional arguments.
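How static_argnums indices line up with positional parameters can be sketched with plain inspect; the function my_op below is hypothetical and stands in for one you would wrap with custom_partitioning:

```python
import inspect

# Hypothetical function that would be wrapped as
# custom_partitioning(my_op, static_argnums=(1,)).
def my_op(x, axis, scale=1.0):
    return x

# static_argnums indexes into the positional parameters reported by
# inspect.signature(fun); index 1 here names `axis`, which would
# therefore be treated as static (not traced).
params = list(inspect.signature(my_op).parameters)
print(params[1])  # -> 'axis'
```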

Example

As an example, assume we want to enhance the existing jax.numpy.fft.fft. This function computes the discrete Fourier transform of an N-dimensional input along the last dimension, and is batched along the first N-1 dimensions. By default, however, it will ignore the sharding of the input and gather the input on all devices. Since jax.numpy.fft.fft is batched along the first N-1 dimensions, this is unnecessary. We will create a new my_fft op that instead does not alter the sharding along the first N-1 dimensions, and only gathers the input along the last dimension if needed.

import jax
from jax.sharding import NamedSharding
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.numpy.fft import fft
import re
import numpy as np

# Pattern to detect all-gather or dynamic-slice in the generated HLO
_PATTERN = '(dynamic-slice|all-gather)'

# For an N-D input, keeps sharding along the first N-1 dimensions
# but replicate along the last dimension
def supported_sharding(sharding, shape):
    rank = len(shape.shape)
    max_shared_dims = min(len(sharding.spec), rank-1)
    names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims))
    return NamedSharding(sharding.mesh, P(*names))

def partition(mesh, arg_shapes, result_shape):
    result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return (mesh, fft,
            supported_sharding(arg_shardings[0], arg_shapes[0]),
            (supported_sharding(arg_shardings[0], arg_shapes[0]),))

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return supported_sharding(arg_shardings[0], arg_shapes[0])

@custom_partitioning
def my_fft(x):
    return fft(x)

# Use Einsum-like notation to specify the sharding rule.
my_fft.def_partition(
  infer_sharding_from_operands=infer_sharding_from_operands,
  partition=partition,
  sharding_rule='...i -> ...i')
# Use an SdyShardingRule object to specify the sharding rule.
my_fft.def_partition(
  infer_sharding_from_operands=infer_sharding_from_operands,
  partition=partition,
  sharding_rule=SdyShardingRule(operand_mappings=((SDY_BATCHING, 'i'),), result_mappings=((SDY_BATCHING, 'i'),)))

Now create a 2D array sharded along the first axis, pass it through my_fft, and notice how it is still sharded as expected, and identical to the output of fft. However, inspecting the HLO (using lower(x).compile().runtime_executable().hlo_modules()) reveals that my_fft does not create any all-gather or dynamic-slice, while fft does.

with Mesh(np.array(jax.devices()), ('x',)):
  x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
  y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
  pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
  pjit_fft    = pjit(fft,    in_shardings=P('x'), out_shardings=P('x'))
  print(pjit_my_fft(y))
  print(pjit_fft(y))
  # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array
  assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
  # dynamic-slice or all-gather are present in the HLO for fft
  assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string())    is not None)
# my_fft
[[-38.840824   +0.j        -40.649452  +11.845365j
...
  -1.6937828  +0.8402481j  15.999859   -4.0156755j]]

# jax.numpy.fft.fft
[[-38.840824   +0.j        -40.649452  +11.845365j
  ...
  -1.6937828  +0.8402481j  15.999859   -4.0156755j]]

Because of the logic in supported_sharding, my_fft also works on 1-dimensional arrays. However, in this case, the HLO of my_fft does show a dynamic-slice, since the last dimension is the dimension along which FFTs are calculated and it needs to be replicated on all devices before the computation can be done.

with Mesh(np.array(jax.devices()), ('x',)):
  x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
  y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
  pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
  pjit_fft    = pjit(fft,    in_shardings=P('x'), out_shardings=P('x'))
  print(pjit_my_fft(y))
  print(pjit_fft(y))
  # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array
  assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
  # dynamic-slice or all-gather are present in the HLO for fft
  assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string())    is not None)
# my_fft
[    7.217285   +0.j     -3012.4937  +4287.635j   -405.83594 +3042.984j
...  1422.4502  +7271.4297j  -405.84033 -3042.983j
-3012.4963  -4287.6343j]

# jax.numpy.fft.fft
[    7.217285   +0.j     -3012.4937  +4287.635j   -405.83594 +3042.984j
...  1422.4502  +7271.4297j  -405.84033 -3042.983j
-3012.4963  -4287.6343j]