jax.make_array_from_process_local_data#

jax.make_array_from_process_local_data(sharding, local_data, global_shape=None)[原始碼]#

使用進程中可用的資料建立分散式張量。

此函式是 make_array_from_callback 的常見特例。它假設資料在進程中可用,並處理索引操作。

最常見的情況是當分片跨批次維度分片,且每個主機僅載入其對應的子批次時。此函式也支援更一般的情況,例如混合多主機和多軸複製與分片,但您需要正確計算進程本地資料的大小和內容,以滿足分片約束。

特別是,如果任何兩個主機是副本,則 host_local_data 也應相同。

global_shape 是可選的。如果未提供,將從 local_data 和分片推斷,假設每個主機僅代表其自身的均勻分片資料。如果分片是非均勻的(請參閱以下註解),則會引發例外。

明確設定 global_shape 可以實現更精細的控制,並適用於非均勻分片。global_shape 的每個維度必須與 host_local_data 相符,或與分片的推斷全域形狀相符(在這種情況下,它等同於將其設定為 None,但更明確)。

例如,如果維度 i 完全分片,則此大小將為 per_device_shape[i] * jax.local_device_count()。每個裝置將被映射到 local_data 陣列的本地切片中。例如,如果給定進程定址切片 (8, 12) 和 (24, 28),則這些切片將被映射到 local_data 的 (0, 4) 和 (4, 8)。

對於 global_shapes 與 local_shape 相符的每個維度,每個裝置將在 local_data 中查找切片。例如,如果 global_shape == local_data.shape,則假設本地資料是要分片到裝置的實際目標陣列。

如果 global_shape 與 local_data.shape 相同,則資料在所有主機之間必須相同。

範例

>>> from jax.sharding import PartitionSpec as P
>>> mesh_rows = 2
>>> mesh_cols =  jax.device_count() // 2
...
>>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y'),))
>>> rows_per_device = 2
>>> feature_length = 32
>>> per_device_shape = (rows_per_device, feature_length)
>>> per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length)
>>> per_host_generator = lambda : np.arange(np.prod(per_host_shape)).reshape(per_host_shape)
>>> per_host_data = per_host_generator()  # replace with your own per-host data pipeline that outputs numpy arrays
>>> global_shape = (rows_per_device * len(sharding.device_set), ) + per_device_shape[1:]
>>> output_global_array = jax.make_array_from_process_local_data(sharding, per_host_data, global_shape)
...
>>> assert output_global_array.addressable_data(0).shape == per_device_shape
>>> assert output_global_array.shape == global_shape

注意:雖然大多數分片是均勻的,但可以設計奇特的分片網格,其中每個進程的裝置在某些維度中將以非網格狀模式排列,或者索引以非平凡的方式重疊。這種分片在這些維度中稱為「非均勻」。在這種情況下,沿這些方向的全域形狀必須與本地形狀相符,因為沒有有意義的方式以非重疊方式表示所有需要的每個進程資料。例如,對於全域形狀 4x4,如果分片看起來像這樣

0123 2103 4675 4567

使用 4 個進程,分別包含裝置 (0,1)、(2, 3)、(4, 5)、(6, 7)。然後每個主機的資料看起來像

xx.. ..xx …. …. .xx. x..x …. …. …. …. x..x .xx. …. …. xx.. ..xx

分片在列上是均勻的(每個主機需要第 1-2 列或第 3-4 列),在行上是非均勻的(主機需要重疊但不匹配的行集合)。因此,所有主機的本地資料必須具有形狀 2x4 或 4x4,即使每個主機都可以潛在地位於 2x2 形狀中。在這種情況下,使用者必須明確提供 global_shape,對於 local_shape=(2, 4),潛在有效的全域形狀為 (2, 4) 和 (4, 4)。

另一方面,對於分片

0213 x.x. .x.x. …. …. 0213 x.x. .x.x. …. …. 4657 …. …. .x.x x.x. 4657 …. …. .x.x x.x.

對於 local_shape=(2, 2),此函式可以接受 2x2、2x4、4x2 和 4x4 全域形狀的選擇。在這種情況下,將 global_shape 設定為 None 等同於將其設定為 (4, 4)。

參數:
  • sharding (Sharding) – 全域陣列的分片。

  • local_data (np.ndarray) – 主機上的資料,將放置在本地裝置上。每個維度應與 global_shape 相符,或與 num_addressable_indices(dim) 相符。

  • global_shape (Shape | None | None) – 全域陣列的目標形狀。如果為 None,將從 local_data 和分片推斷。

傳回:

將具有 sharding=sharding 且形狀為 global_shape 的張量。

傳回類型:

ArrayImpl