shmap (shard_map) for simple per-device code#

sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@

2023 年 1 月

這是提議 shard_map 的設計文件。您可能反而想要最新的使用者文件

動機#

JAX 支援多裝置程式設計的兩種思路

  1. 編譯器,接手方向盤! 讓編譯器自動將 bulk 陣列函式分割到裝置上。

  2. 就讓我寫出我的意思,可惡! 給我 per-device 程式碼和顯式通訊集合體 (explicit communication collectives)。

我們需要適用於兩者的出色 API,而且它們不應該是互斥的替代方案,它們需要彼此組合。

使用 pjit(現在只是 jit),我們為第一種思路提供了下一代 API。但是我們還沒有完全提升第二種思路。pmap 遵循第二種思路,但隨著時間推移,我們發現它有致命的缺陷xmap 解決了這些缺陷,但它並沒有完全給我們 per-device 形狀,而且它還包含其他幾個大概念。同時,對於 per-device 顯式集合體程式設計的新需求已經出現,例如在Efficiently Scaling Transformer Inference中。

我們可以使用 shmap 來提升第二種思路。shmap

  • 一個簡單的多裝置平行化 API,讓我們可以使用顯式集合體編寫 per-device 程式碼,其中邏輯形狀與 per-device 物理緩衝區形狀相符,並且集合體與跨裝置通訊完全對應;

  • xmap 的特化版本,具有縮減的功能和一些調整;

  • XLA SPMD Partitioner 的「手動」模式的相當直接的呈現;

  • 一個有趣發音的蘇斯式名稱,可以代表 shard_mapshpecialized_xmapsholto_mapsharad_map

對於 pjit 使用者shmap 是一個互補工具。它可以在 pjit 計算內部使用,以暫時進入「手動集合體」模式,就像編譯器的自動分割的逃生出口。這樣,使用者可以在他們的大部分程式碼中獲得 pjit 的便利性和熟悉的 just-NumPy 程式設計模型,以及在需要時使用 shmap 手動最佳化集合體通訊的能力。這是兩全其美!

對於 pmap 使用者shmap 是一個嚴格的升級。它更具表現力、效能更好,並且可以與其他 JAX API 組合,而不會使基本批次資料平行化變得更困難。

有關實際使用的更多資訊,您可以跳到何時應該使用 shmap,何時應該使用 pjit。如果您想知道為什麼我們需要一個新事物,或者 pmap 有什麼問題,請跳到為什麼 pmapxmap 還不能解決這個問題?。或繼續閱讀下一節,以查看一些 shmap 範例和 API 規格。

那麼,讓我們看看 shmap 吧!#

TL;DR 範例(附帶更詳細的解釋)#

Sho shick

from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((4, 2), ('i', 'j'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 32]
  z_partialsum = jnp.dot(a_block, b_block)
  z_block = jax.lax.psum(z_partialsum, 'j')
  return z_block

c = matmul_basic(a, b)  # c: f32[8, 32]

請注意

  • pmap 不同,多個平行化軸不需要巢狀結構(或 axis_index_groups);

  • pmap 和 hard-xmap 不同,呼叫者中沒有 reshapes,並且邏輯形狀對應於 per-device 物理形狀,這與(非 hard)xmap 不同;

  • pmap 不同,使用 mesh 進行精確的裝置放置控制;

  • xmap 不同,邏輯和物理只有一組軸名稱;

  • 結果是一個 jax.Array,它可以有效地傳遞給 pjit,這與 pmap 不同;

  • pmap 不同,相同的程式碼在 pjit/jit 內部也能有效運作;

  • 此程式碼以 eager 方式運作,因此我們可以在中間使用 pdb 並列印數值,這與 xmap 的目前實作不同(雖然在設計上,沒有循序排程的 xmap 原則上也可以 eager 方式運作)。

這是另一個 matmul 變體,具有完全分片的結果

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', 'j'))
def matmul_reduce_scatter(a_block, b_block):
  # c_partialsum: f32[8/X, 32]
  c_partialsum = jnp.matmul(a_block, b_block)
  # c_block: f32[8/X, 32/Y]
  c_block = jax.lax.psum_scatter(c_partialsum, 'j', scatter_dimension=1, tiled=True)
  return c_block

c = matmul_reduce_scatter(a, b)

慢下來,從基礎開始!#

陣列軸的 rank-reducing 與 rank-preserving 映射#

我們可以將 pmap(以及 vmapxmap)視為沿著軸 unstacking 每個陣列輸入(例如,將 2D 矩陣解包成其 1D 列),將其主體函式應用於每個部分,然後將結果堆疊回一起,至少在不涉及集合體時是這樣

pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs])

例如,如果 xs 的形狀為 f32[8,5],則每個 x 的形狀為 f32[5],並且如果每個 f(x) 的形狀為 f32[3,7],則最終堆疊的結果 pmap(f)(xs) 的形狀為 f32[8,3,7]。也就是說,主體函式 f 的每次應用都將軸數比 pmap(f) 的對應引數少一個軸的輸入作為引數。我們可以說這些是rank-reducing 映射,具有輸入/輸出的 unstacking/stacking。

f 的邏輯應用次數由要映射的輸入軸的大小決定:例如,如果我們映射大小為 8 的輸入軸,則在語意上我們得到 8 次函式的邏輯應用,對於 pmap,這始終對應於 8 個裝置在物理上計算它們。

相反,shmap 沒有這種 rank-reducing 行為。相反,我們可以將其視為沿著輸入軸 slicing(或「unconcatenating」)成區塊,應用主體函式,然後將結果串聯回一起(再次在不涉及集合體時)

devices = np.array(jax.devices()[:4])
m = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y)
==
jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])

回想一下,jnp.split 將其輸入 slicing 成相同大小的區塊,這些區塊具有相同的 rank,因此如果在上面的範例中 y 的形狀為 f32[8,5],則每個 y_blk 的形狀為 f32[2,5],並且如果每個 f(y_blk) 的形狀為 f32[3,7],則最終串聯的結果 shard_map(f, ...)(y) 的形狀為 f32[12,7]。因此,shmap (shard_map) 映射其輸入的分片或區塊。我們可以說它是rank-preserving 映射,具有輸入/輸出的 unconcatenating/concatenating。

f 的邏輯應用次數由 mesh 大小決定,而不是由任何輸入軸大小決定:例如,如果我們有一個總大小為 4 的 mesh(即在 4 個裝置上),則在語意上我們得到 4 次函式的邏輯應用,對應於 4 個裝置在物理上計算它們。

使用 in_specs 控制每個輸入如何分割 (unconcatenated) 和 tiling#

每個 in_specs 都使用 PartitionSpecs 通過名稱識別一些對應輸入陣列的軸與 mesh 軸,表示如何將該輸入分割(或 unconcatenate)成主體函式應用於的區塊。這種識別確定了分片大小;當輸入軸與 mesh 軸識別時,輸入沿著該邏輯軸分割(unconcatenate)成等於對應 mesh 軸大小的塊數。(如果對應的 mesh 軸大小不能均勻地除輸入陣列軸大小,則會發生錯誤。)如果輸入的 pspec 沒有提及 mesh 軸名稱,則不會在該 mesh 軸上進行分割。例如

devices = np.array(jax.devices())
m = Mesh(devices.reshape(4, 2), ('i', 'j'))

@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)
  return x_block

x1 = np.arange(12 * 12).reshape(12, 12)
y = f1(x1)  # prints (3,12)

在這裡,由於輸入 pspec 沒有提及 mesh 軸名稱 'j',因此沒有輸入陣列軸在該 mesh 軸上分割;同樣地,由於輸入陣列的第二個軸未與任何 mesh 軸識別(因此未在其上分割),因此 f1 的應用程式獲得了沿該軸的輸入的完整視圖。

當輸入 pspec 中未提及 mesh 軸時,我們始終可以重寫為效率較低的程式,其中提及了所有 mesh 軸,但呼叫者執行 jnp.tile,例如

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = np.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.axis_size['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)

換句話說,由於每個輸入 pspec 可以提及每個 mesh 軸名稱零次或一次,而不是必須精確地提及每個名稱一次,因此我們可以說,除了內建於其輸入中的 jnp.split 之外,shard_map 還具有內建於其輸入中的 jnp.tile,至少在邏輯上是這樣(儘管 tiling 可能不需要物理執行,具體取決於引數的物理分片佈局)。要使用的 tiling 不是唯一的;我們也可以沿著第一個軸 tiling,並使用 pspec P(('j', 'i'), None)

輸入端可能發生物理資料移動,因為每個裝置都需要擁有適當資料的副本。

使用 out_specs 控制如何通過串聯、區塊轉置和 untiling 組裝每個輸出#

與輸入端類似,每個 out_specs 都通過名稱識別一些對應輸出陣列的軸與 mesh 軸,表示應該如何將輸出區塊(主體函式的每次應用一個,或等效地每個物理裝置一個)組裝回一起以形成最終輸出值。例如,在上面的 f1f2 範例中,out_specs 指示我們應該通過沿兩個軸將區塊結果串聯在一起來形成最終輸出,從而在兩種情況下都產生形狀為 (12,24) 的陣列 y。(如果主體函式的輸出形狀(即輸出區塊形狀)的 rank 太小,以至於無法進行對應輸出 pspec 描述的串聯,則會發生錯誤。)

當輸出 pspec 中未提及 mesh 軸名稱時,它表示un-tiling:當使用者編寫未提及 mesh 軸名稱之一的輸出 pspec 時,他們承諾輸出區塊沿該 mesh 軸相等,因此輸出中僅使用沿該軸的一個區塊(而不是沿該 mesh 軸將所有區塊串聯在一起)。例如,使用與上面相同的 mesh

x = jnp.array([[3.]])

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))()
print(z)  # prints the same as jnp.tile(x, (4, 2))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))()
print(z)  # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))()
print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x

請注意,閉包陣列值的主體函式等效於將其作為帶有 P(None, None) 的對應輸入 pspec 的擴充引數傳遞。作為另一個範例,更密切地遵循上面的其他範例

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = np.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)  # (12,6)

請注意,結果的第二個軸大小為 6,是輸入第二個軸大小的一半。在這種情況下,由於集合體 psum,通過在輸出 pspec 中不提及 mesh 軸名稱 'j' 表達的 un-tile 是安全的,這確保了每個輸出區塊沿著對應的 mesh 軸相等。以下是另外兩個範例,我們在其中更改了輸出 pspec 中提及的 mesh 軸

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = np.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)


@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)

在物理方面,在輸出 pspec 中不提及 mesh 軸名稱會從沿該 mesh 軸具有複製佈局的輸出裝置緩衝區組裝 Array

沒有執行階段檢查來驗證輸出區塊是否確實沿著要 untile 的 mesh 軸相等,或者等效地,對應的物理緩衝區是否具有相等的值,因此可以解釋為單個邏輯陣列的複製佈局。但是我們可以提供一個靜態檢查機制,該機制會在所有可能不正確的程式上引發錯誤。

由於 out_specs 可以提及 mesh 軸名稱零次或一次,並且由於它們可以以任何順序提及,因此我們可以說,除了內建於其輸出中的 jnp.concatenate 之外,shard_map 還具有內建於其輸出中的 untile 和區塊轉置。

無論輸出 pspec 如何,輸出端都不可能發生物理資料移動。相反,out_specs 僅編碼如何將區塊輸出組裝成 Array,或者物理上如何將跨裝置的緩衝區解釋為單個邏輯 Array 的物理佈局。

API 規格#

from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
          ) -> Callable:
  ...

其中

  • mesh 編碼排列在陣列中且具有關聯軸名稱的裝置,就像它對 xmapsharding.NamedSharding 所做的那樣;

  • in_specsout_specsPartitionSpecs,它們可以仿射地提及來自 mesh 的軸名稱(不是像 xmap 中那樣的單獨邏輯名稱),以分別表示輸入和輸出的 slicing/unconcatenation 和串聯(不是像 pmapxmap 那樣的 unstacking 和 stacking),未提及的名稱分別對應於複製和 untiling(assert-replicated-so-give-me-one-copy);

  • 傳遞給 f 的引數的形狀與傳遞給 shard_map-of-f 的引數具有相同的 rank(與 pmapxmap 不同,後者的 rank 會降低),並且 f 的引數的形狀是根據 shard_map-of-f 的對應引數的形狀 shape 和對應的 PartitionSpec 規格計算得出的,大致為 tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))

  • f 的主體可以使用來自 mesh 的名稱應用集合體。

shmap 預設為 eager 模式,這表示我們逐個 primitive 調度計算,以便使用者可以在完全複製的值上使用 Python 控制流程和互動式 pdb 除錯來列印任何值。要 staged out 並端對端編譯 shmap 函式,只需在其周圍放置一個 jit。一個結果是,shmap 沒有像 xmappmap 目前那樣的自己的調度和編譯路徑;它只是 jit 路徑。

當通過例如封閉的 jit 進行 staged out 時,將 shmap 降低到 StableHLO 的過程很簡單:它僅涉及在輸入端切換到「手動 SPMD 模式」,然後在輸出端切換回來。(我們目前不打算支援部分手動部分自動模式。)

與效果的交互作用與 pmap 相同。

與自動微分的互動也就像 pmap 一樣(而不是嘗試 xmap 所做的新語義,對應於擁有未映射的中介值,因此 gradreduce_axes 以及使 psum 轉置為 pbroadcast 而不是 psum)。但因此它也繼承了 pmap 的一個未解決的問題:在某些情況下,與其將 psum 轉置為 psum,從而在反向傳播中執行與正向傳播 psum 對應的 psum,不如將反向傳播 psum 移動到反向傳播中的其他位置,利用線性性質可能更有益。許多進階 pmap 使用者透過使用 custom_vjp 來實作 psum_idrevid_psumrev 函數來解決這個挑戰,但由於很容易不小心讓它們不平衡,因此這種技術就像是個土製大砲。我們對於如何以更安全的方式提供此功能有一些想法。

何時應該使用 shmap,又何時應該使用 pjit 呢?#

一種哲學是:幾乎總是使用 jit==pjit 編寫程式更簡單 — 但如果程式的某個部分編譯器最佳化的程度不如它可能達到的程度,那就改用 shmap

一個實際的範例#

以下是在具有 2D 權重收集模式的 Transformer 層傳遞中,shmap 可能看起來的樣子(論文,第 5 頁的 3.2.3 節)

def matmul_2D_wg_manual(xnorm, q_wi, layer):
  '''Calls a custom manual implementation of matmul_reducescatter'''
  # [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
  # -> (matmul)
  # -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
  # -> (reducescatter over x into X heads, B batches)
  # -> [batch, maxlen, heads.YZX, q_wi_per_head]
  with jax.named_scope('q_wi'):
    xnorm = intermediate_dtype(xnorm)
    q_wi = matmul_reducescatter(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        scatter_dimension=(0, 2),
        axis_name='i',
        layer=layer)
   return q_wi


import partitioning.logical_to_physical as l2phys

def pjit_transformer_layer(
    hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray,
    cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache],
    x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Forward pass through a single layer, returning output, K, V."""

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  # 2D: [batch.Z, time, embed.XY]
  x = _with_sharding_constraint(
      x, ('residual_batch', 'residual_time', 'residual_embed'))
  xnorm = _layernorm(x)
  # 2D: [batch, time, embed.X]
  xnorm = _with_sharding_constraint(
      xnorm, ('post_norm_batch', 'time', 'post_norm_embed'))
  # jump into manual mode where you want to optimise
  if manual:
    q_wi = shard_map(matmul_2D_wg_manual, mesh
                in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'),
                          l2phys('layers', 'heads', 'embed', 'q_wi_per_head')),
                out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(xnorm, q_wi, layer)
  else:
    q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi))
    # 2D: [batch, time, heads.YZX, None]
    q_wi = _with_sharding_constraint(q_wi,
                                   ('post_norm_batch', 'time', 'heads', 'qkv'))
  q = q_wi[:, :, :, :hparams.qkv]
  q = _rope(sin, cos, q)
  # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
  # swiGLU with full d_ff dimension, rather than 2/3 scaled
  wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)]
  wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):]
  kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))
  k = kv[:, :, 0, :hparams.qkv]
  v = kv[:, :, 0, hparams.qkv:]
  k = _rope(sin, cos, k)

  y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))

  y_mlp = special2.swish2(wi0) * wi1
  # 2D: [batch, time, heads.YZX, None]
  y_mlp = _with_sharding_constraint(y_mlp,
                                    ('post_norm_batch', 'time', 'heads', None))

  y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
  # do the second half of the mlp and the self-attn projection in parallel
  y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
  # 2D: [batch.Z, time, embed.XY]
  y_out = _with_sharding_constraint(
      y_out, ('residual_batch', 'residual_time', 'residual_embed'))
  z = y_out + x
  z = _with_sharding_constraint(
      z, ('residual_batch', 'residual_time', 'residual_embed'))
  return z, k, v

在下面的分析中,第一個和第二個 matmul 都被手動降低的版本取代,其中計算(融合)與通訊(ppermute)完全重疊!一個有趣的提示,表明我們正在使用延遲最佳化的變體是 ppmerute 像素是抖動的 — 因為有兩個重疊的 ppermute 同時使用相反的 ICI 軸!

All-to-all 更難重疊,因此被擱置。

image

為什麼 pmapxmap 還沒有解決這個問題?#

pmap 是我們第一個多裝置平行處理 API。它遵循每個裝置程式碼和顯式集合運算的原則。但它有重大的缺點,使其不適合今天的程式

  • 映射多個軸需要巢狀 pmap 巢狀 pmap 不僅難以編寫,而且它們也使得難以控制(甚至預測)資料和計算的裝置放置,並且難以保留資料分片(請參閱接下來的兩個要點)。今天的程式需要多個平行處理軸。

  • 無法控制裝置放置。 特別是在具有多個平行處理軸的情況下,程式設計師需要控制這些軸如何與硬體資源及其通訊拓撲對齊。但是(巢狀)pmap 不提供對映射程式實例如何放置在硬體上的控制;只有使用者無法控制的自動裝置順序。(Gopher 使用 axis_index_groups 和單個非巢狀 pmap 本質上是一種駭客技巧,透過將多個平行處理軸展平為一個來繞過這個問題。)

  • jit/pjit 可組合性。 jit-of-pmap 是一個效能陷阱,巢狀 pmap 也是如此,例如 scan-of-pmap 也是如此,因為當從內部 pmap 返回時,分片不會被保留。為了保留分片,我們需要在 jaxpr 上進行模式比對,以確保我們正在使用完美巢狀的 pmap,或者只是在 jit 內部的 pmap。此外,pjit 在這裡沒有幫助,因為 pmap 目標是 XLA 副本,而 pjit 目標是 XLA SPMD 分割器,並且組合這兩者很困難。

  • jax.Array 相容性(以及因此的 pjit 相容性)。 由於 pmap 輸出的分片無法表示為 Shardings / OpShardings,因為 pmap 的堆疊而非串聯語義,pmap 計算的輸出目前無法傳遞到 pjit 計算,而無需彈跳到主機(或調度重塑計算)。

  • 多控制器語義(以及因此的 pjit 相容性)。 多控制器 pmap 跨控制器串聯值,這運作良好,但與單控制器 pmap 的堆疊語義不同。更實際的是,它排除了使用非完全可定址的 jax.Array 輸入和輸出,正如我們在多控制器 pjit 中所使用的那樣。

  • Eager 模式。 我們沒有將 pmap 設為 eager-first,雖然我們最終(在 4 年多之後!)新增了使用 disable_jit() 的 eager 操作,但 pmapjit 融合到其中的事實意味著它有自己的編譯和調度路徑(實際上是兩個調度路徑:在 Python 中用於處理 Tracer,在 C++ 中用於原始 Array 輸入的效能!),這是一個沉重的實作負擔。

  • 呼叫者中需要重塑。 在 8 個裝置上使用 pmap 的典型用例可能看起來像從大小為 128 的批次軸開始,將其重塑以拆分為大小為 (8, 16) 的兩個軸,然後在第一個軸上進行 pmap。這些重塑很笨拙,並且編譯器通常將它們解釋為複製而不是視圖 — 增加了記憶體和時間使用量。

當僅進行批次資料平行處理時,這些缺點並不算太糟。但是當涉及更多平行處理時,pmap 就是無法勝任!

xmap 作為 pmap 的下一代演進鋪平了道路,並解決了(幾乎)所有這些問題。shmap 追隨 xmap 的腳步,並以基本相同的方式解決了這些問題;實際上,shmap 就像 xmap 的一個特殊子集(有些人稱之為“硬 xmap” 子集),並進行了一些調整。

對於初始原型,我們選擇將 shmap 實作為與 xmap 分開的 primitive,因為限制它支援的功能集可以更輕鬆地專注於核心功能。例如,shmap 不允許未映射的中介值,這使得不必擔心具名軸和自動微分之間的互動變得更容易。此外,不必推理所有功能對之間的互動使得更容易新增超出今天在 xmap 中實作的功能,例如支援 eager 模式。

shmapxmap 都共享大量的降低程式碼。我們可以考慮在未來合併兩者,甚至只專注於 shmap,這取決於使用情況將如何演變。