jax.experimental.mesh_utils.create_device_mesh#

jax.experimental.mesh_utils.create_device_mesh(mesh_shape, devices=None, *, contiguous_submeshes=False, allow_split_physical_axes=False)[原始碼]#

為 jax.sharding.Mesh 建立高效能的裝置網格。

參數:
  • mesh_shape (Sequence[int]) – 邏輯網格的形狀,依網路密集度遞增排序,例如 [replica, data, mdl],其中 mdl 具有最多的網路通訊需求。

  • devices (Sequence[Any] | None | None) – (可選)用於建構網格的裝置。預設為 jax.devices()。

  • contiguous_submeshes (bool) – 如果為 True,此函式將嘗試建立一個網格,其中每個進程的本機裝置形成一個連續的子網格。如果此函式無法產生合適的網格,則會引發 ValueError。在引入 jax.Array 之前,此設定有時是必要的,以確保非參差不齊的本機陣列;如果使用 jax.Arrays,最好將此設定保持為 False。

  • allow_split_physical_axes (bool) – 如果為 True,我們將在必要時分割物理軸,以產生所需的裝置網格。

引發:

ValueError – 如果裝置數量不等於 mesh_shape 的乘積。

回傳:

一個 np.ndarray 形式的 JAX 裝置,其形狀為 mesh_shape,可以饋送到 jax.sharding.Mesh 中,以獲得良好的集體效能。

回傳型別:

np.ndarray