jax.make_array_from_process_local_data#
- jax.make_array_from_process_local_data(sharding, local_data, global_shape=None)[原始碼]#
使用進程中可用的資料建立分散式張量。
此函式是 make_array_from_callback 的常見特例。它假設資料在進程中可用,並處理索引操作。
最常見的情況是當分片跨批次維度分片,且每個主機僅載入其對應的子批次時。此函式也支援更一般的情況,例如混合多主機和多軸複製與分片,但您需要正確計算進程本地資料的大小和內容,以滿足分片約束。
特別是,如果任何兩個主機是副本,則 host_local_data 也應相同。
global_shape 是可選的。如果未提供,將從 local_data 和分片推斷,假設每個主機僅代表其自身的均勻分片資料。如果分片是非均勻的(請參閱以下註解),則會引發例外。
明確設定 global_shape 可以實現更精細的控制,並適用於非均勻分片。global_shape 的每個維度必須與 host_local_data 相符,或與分片的推斷全域形狀相符(在這種情況下,它等同於將其設定為 None,但更明確)。
例如,如果維度 i 完全分片,則此大小將為 per_device_shape[i] * jax.local_device_count()。每個裝置將被映射到 local_data 陣列的本地切片中。例如,如果給定進程定址切片 (8, 12) 和 (24, 28),則這些切片將被映射到 local_data 的 (0, 4) 和 (4, 8)。
對於 global_shapes 與 local_shape 相符的每個維度,每個裝置將在 local_data 中查找切片。例如,如果 global_shape == local_data.shape,則假設本地資料是要分片到裝置的實際目標陣列。
如果 global_shape 與 local_data.shape 相同,則資料在所有主機之間必須相同。
範例
>>> from jax.sharding import PartitionSpec as P >>> mesh_rows = 2 >>> mesh_cols = jax.device_count() // 2 ... >>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(mesh, P(('x', 'y'),)) >>> rows_per_device = 2 >>> feature_length = 32 >>> per_device_shape = (rows_per_device, feature_length) >>> per_host_shape = (rows_per_device * len(mesh.local_devices), feature_length) >>> per_host_generator = lambda : np.arange(np.prod(per_host_shape)).reshape(per_host_shape) >>> per_host_data = per_host_generator() # replace with your own per-host data pipeline that outputs numpy arrays >>> global_shape = (rows_per_device * len(sharding.device_set), ) + per_device_shape[1:] >>> output_global_array = jax.make_array_from_process_local_data(sharding, per_host_data, global_shape) ... >>> assert output_global_array.addressable_data(0).shape == per_device_shape >>> assert output_global_array.shape == global_shape
注意:雖然大多數分片是均勻的,但可以設計奇特的分片網格,其中每個進程的裝置在某些維度中將以非網格狀模式排列,或者索引以非平凡的方式重疊。這種分片在這些維度中稱為「非均勻」。在這種情況下,沿這些方向的全域形狀必須與本地形狀相符,因為沒有有意義的方式以非重疊方式表示所有需要的每個進程資料。例如,對於全域形狀 4x4,如果分片看起來像這樣
0123 2103 4675 4567
使用 4 個進程,分別包含裝置 (0,1)、(2, 3)、(4, 5)、(6, 7)。然後每個主機的資料看起來像
xx.. ..xx …. …. .xx. x..x …. …. …. …. x..x .xx. …. …. xx.. ..xx
分片在列上是均勻的(每個主機需要第 1-2 列或第 3-4 列),在行上是非均勻的(主機需要重疊但不匹配的行集合)。因此,所有主機的本地資料必須具有形狀 2x4 或 4x4,即使每個主機都可以潛在地位於 2x2 形狀中。在這種情況下,使用者必須明確提供 global_shape,對於 local_shape=(2, 4),潛在有效的全域形狀為 (2, 4) 和 (4, 4)。
另一方面,對於分片
0213 x.x. .x.x. …. …. 0213 x.x. .x.x. …. …. 4657 …. …. .x.x x.x. 4657 …. …. .x.x x.x.
對於 local_shape=(2, 2),此函式可以接受 2x2、2x4、4x2 和 4x4 全域形狀的選擇。在這種情況下,將 global_shape 設定為 None 等同於將其設定為 (4, 4)。
- 參數:
sharding (Sharding) – 全域陣列的分片。
local_data (np.ndarray) – 主機上的資料,將放置在本地裝置上。每個維度應與 global_shape 相符,或與 num_addressable_indices(dim) 相符。
global_shape (Shape | None | None) – 全域陣列的目標形狀。如果為 None,將從 local_data 和分片推斷。
- 傳回:
將具有 sharding=sharding 且形狀為 global_shape 的張量。
- 傳回類型:
ArrayImpl