jax.experimental.custom_partitioning
模組#
API#
- jax.experimental.custom_partitioning.custom_partitioning(fun, static_argnums=())[原始碼]#
將 CustomCallOp 插入具有自訂 SPMD 降低規則的 XLA 圖中。
@custom_partitioning def f(*args): return ... def propagate_user_sharding(mesh, user_shape): '''Update the sharding of the op from a user's shape.sharding.''' user_sharding = jax.tree.map(lambda x: x.sharding, user_shape) def partition(mesh, arg_shapes, result_shape): def lower_fn(*args): ... builds computation on per-device shapes ... result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) # result_sharding and arg_shardings may optionally be modified and the # partitioner will insert collectives to reshape. return mesh, lower_fn, result_sharding, arg_shardings def infer_sharding_from_operands(mesh, arg_shapes, shape): '''Compute the result sharding from the sharding of the operands.''' arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands=infer_sharding_from_operands, sharding_rule='i j -> 'i j') When config.use_shardy_partitioner.value is True, the sharding_rule is used; otherwise, propagate_user_sharding and infer_sharding_from_operands are used. Instead of using an Einsum-like notation string, sharding_rule can also be a SdyShardingRule object, such as sharding_rule=SdyShardingRule(("i", "j"), ("i", "j")).
def_partition
的引數如下propagate_user_sharding
:可呼叫物件,接受使用者(在 dag 中)的分片,並傳回新的 NamedSharding 的建議。預設實作只是傳回建議的分片。partition
:可呼叫物件,接受 SPMD 建議的分區形狀和分區規格,並傳回網格、每個分片的降低函式,以及最終的輸入和輸出分片規格(SPMD 分區器將重新分區輸入以符合)。傳回網格是為了在未提供網格時允許設定集合運算的 axis_names。infer_sharding_from_operands
:可呼叫物件,從為每個引數選擇的NamedSharding
計算輸出NamedSharding
。decode_shardings
:設定為 True 時,如果可能,將輸入GSPMDSharding
轉換為NamedSharding
。如果使用者未提供上下文網格,則可能無法轉換。sharding_rule
:SdyShardingRule 物件或類似 Einsum 符號的字串,描述分片規則。我們借鑒了 einops.rearrange 字串的想法,在因子之間使用空格分隔符,並允許使用多個字母的因子名稱。
可以使用 static_argnums 將位置引數指定為靜態。JAX 使用
inspect.signature(fun)
來解析這些位置引數。範例
舉例來說,假設我們要增強現有的
jax.numpy.fft.fft
。此函式計算 N 維輸入沿最後一個維度的離散傅立葉變換,並沿前 N-1 個維度進行批次處理。但是,預設情況下,它會忽略輸入的分片,並在所有裝置上收集輸入。但是,由於jax.numpy.fft.fft
沿前 N-1 個維度進行批次處理,因此這是沒有必要的。我們將建立一個新的my_fft
運算,它不會改變沿前 N-1 個維度的分片,並且僅在需要時沿最後一個維度收集輸入。import jax from jax.sharding import NamedSharding from jax.experimental.custom_partitioning import custom_partitioning from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P from jax.sharding import Mesh from jax.numpy.fft import fft import regex as re import numpy as np # Pattern to detect all-gather or dynamic-slice in the generated HLO _PATTERN = '(dynamic-slice|all-gather)' # For an N-D input, keeps sharding along the first N-1 dimensions # but replicate along the last dimension def supported_sharding(sharding, shape): rank = len(shape.shape) max_shared_dims = min(len(sharding.spec), rank-1) names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims)) return NamedSharding(sharding.mesh, P(*names)) def partition(mesh, arg_shapes, result_shape): result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) return mesh, fft, supported_sharding(arg_shardings[0], arg_shapes[0]), (supported_sharding(arg_shardings[0], arg_shapes[0]),) def infer_sharding_from_operands(mesh, arg_shapes, result_shape): arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) return supported_sharding(arg_shardings[0], arg_shapes[0]) @custom_partitioning def my_fft(x): return fft(x) # Use Einsum-like notation to specify the sharding rule. my_fft.def_partition( infer_sharding_from_operands=infer_sharding_from_operands, partition=partition, sharding_rule='...i -> ...i') # Use SdyShardingRule object to specify the sharding rule. my_fft.def_partition( infer_sharding_from_operands=infer_sharding_from_operands, partition=partition, sharding_rule=SdyShardingRule(operand_mappings=((SDY_BATCHING, 'i'),), result_mappings=((SDY_BATCHING, 'i'),))))
現在建立一個沿第一個軸分片的 2D 陣列,將其傳遞到
my_fft
,並注意它仍然按照預期進行分片,並且與fft
的輸出相同。但是,檢查 HLO(使用lower(x).compile().runtime_executable().hlo_modules()
)會發現my_fft
沒有建立任何 all-gather 或 dynamic-slice,而fft
則有。with Mesh(np.array(jax.devices()), ('x',)): x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64) y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x) pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x')) pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x')) print(pjit_my_fft(y)) print(pjit_fft(y)) # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None) # dynamic-slice or all-gather are present in the HLO for fft assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
# my_fft [[-38.840824 +0.j -40.649452 +11.845365j ... -1.6937828 +0.8402481j 15.999859 -4.0156755j]] # jax.numpy.fft.fft [[-38.840824 +0.j -40.649452 +11.845365j ... -1.6937828 +0.8402481j 15.999859 -4.0156755j]]
由於
supported_sharding
中的邏輯,my_fft
也適用於一維陣列。但是,在這種情況下,my_fft
的 HLO 確實顯示了 dynamic-slice,因為最後一個維度是計算 FFT 的維度,並且需要在執行計算之前在所有裝置上複製。with Mesh(np.array(jax.devices()), ('x',)): x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64) y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x) pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x')) pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x')) print(pjit_my_fft(y)) print(pjit_fft(y)) # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None) # dynamic-slice or all-gather are present in the HLO for fft assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
# my_fft [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j ... 1422.4502 +7271.4297j -405.84033 -3042.983j -3012.4963 -4287.6343j] # jax.numpy.fft.fft [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j ... 1422.4502 +7271.4297j -405.84033 -3042.983j -3012.4963 -4287.6343j]