jax.experimental.mesh_utils.create_hybrid_device_mesh#

jax.experimental.mesh_utils.create_hybrid_device_mesh(mesh_shape, dcn_mesh_shape, devices=None, *, process_is_granule=False, should_sort_granules_by_key=True, allow_split_physical_axes=False)[source]#

為混合式(例如 ICI 和 DCN)平行處理建立裝置網格。

參數:
  • mesh_shape (Sequence[int]) – 較快/內部網路的邏輯網格形狀,依網路強度遞增排序,例如 [replica, data, mdl],其中 mdl 具有最高的網路通訊需求。

  • dcn_mesh_shape (Sequence[int]) – 較慢/外部網路的邏輯網格形狀,順序與 mesh_shape 相同。

  • devices (Sequence[Any] | None | None) – (可選)用於建構網格的裝置。預設為 jax.devices()。

  • process_is_granule (bool) – 若為 True,此函式會將進程視為較慢/外部網路的單位。否則,它會尋找裝置上的 slice_index 屬性,並使用切片作為單位。啟用此選項旨在作為不設定 slice_index 的平台的後備方案。

  • should_sort_granules_by_key (bool) – 是否應依據 Granule 金鑰(切片或進程索引,取決於 process_is_granule)排序裝置 Granule。

  • allow_split_physical_axes (bool) – 若為 True,必要時我們會分割物理軸,以產生所需的裝置網格。

引發:

ValueError – 如果 devices 所屬的切片數量不等於 dcn_mesh_shape 的乘積,或者任何單一切片所屬的裝置數量不等於 mesh_shape 的乘積。

返回:

形狀為 mesh_shape * dcn_mesh_shape 的 JAX 裝置 np.ndarray,可饋入 jax.sharding.Mesh 以進行混合式平行處理。

返回類型:

np.ndarray