jax.experimental.shard_map.shard_map#

jax.experimental.shard_map.shard_map(f, mesh, in_specs, out_specs, check_rep=True, auto=frozenset({}))[原始碼]#

將函式對應到資料分片。

注意

shard_map 是一個實驗性 API,仍可能變更。如需分片資料的簡介,請參閱平行程式設計簡介。如需更深入了解如何使用 shard_map,請參閱使用 shard_map 的 SPMD 多裝置平行處理

參數:
  • f (Callable) – 要對應的可呼叫物件。f 的每次應用,或 f 的「實例」,都會將已對應引數的分片作為輸入,並產生輸出的分片。

  • mesh (Mesh | AbstractMesh) – jax.sharding.Mesh,代表用於對資料進行分片以及執行 f 實例的裝置陣列。Mesh 的名稱可用於 f 中的集體通訊操作。這通常由公用程式函式建立,例如 jax.experimental.mesh_utils.create_device_mesh()

  • in_specs (Specs) – 一個 pytree,其葉節點為 PartitionSpec 實例,其樹狀結構是要對應之 args tuple 的樹狀前綴。與 NamedSharding 類似,每個 PartitionSpec 代表對應的引數 (或引數子樹) 應如何沿著 mesh 的具名軸分片。在每個 PartitionSpec 中,在某個位置提及 mesh 軸名稱表示沿著該位置軸對應的引數陣列軸進行分片;不提及軸名稱表示複製。如果引數或引數子樹具有 None 的對應規格,則該引數不會分片。

  • out_specs (Specs) – 一個 pytree,其葉節點為 PartitionSpec 實例,其樹狀結構是 f 輸出的樹狀前綴。每個 PartitionSpec 代表應如何串連對應的輸出分片。在每個 PartitionSpec 中,在某個位置提及 mesh 軸名稱表示沿著對應的位置軸串連該網格軸的分片。不提及 mesh 軸名稱表示承諾輸出值在該網格軸上相等,並且應僅產生單一值,而不是串連。

  • check_rep (bool) – 如果為 True (預設值),則啟用額外的有效性檢查和自動微分最佳化。有效性檢查與未在 out_specs 中提及的任何網格軸名稱是否與 f 的輸出如何複製一致有關。如果在 f 中使用 Pallas 核心,則必須設為 False。

  • auto (frozenset[AxisName]) – (實驗性) 一組來自 mesh 的選用軸名稱,我們不會對這些軸名稱進行資料分片或對應函式,而是允許編譯器控制分片。這些名稱不能用於 in_specsout_specsf 中的通訊集合。

返回:

一個可呼叫物件,用於跨根據 meshin_specs 分片的資料應用輸入函式 f

範例

如需範例,請參閱平行程式設計簡介使用 shard_map 的 SPMD 多裝置平行處理