jax.experimental.multihost_utils.host_local_array_to_global_array#
- jax.experimental.multihost_utils.host_local_array_to_global_array(local_inputs, global_mesh, pspecs)[原始碼]#
將主機本機值轉換為全域分片的 jax.Array。
此函數接受主機本機資料(在不同主機上可能不同),並使用此資料填充全域陣列,其中每個主機上的每個裝置都根據 global_mesh/pspects 定義的分片獲得適當的資料切片。
例如
>>> global_mesh = jax.sharding.Mesh(jax.devices(), 'x') >>> pspecs = jax.sharding.PartitionSpec('x') >>> host_id = jax.process_index() >>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs) # NB: assumes jax.local_device_count() divides 4.
產生的陣列將具有形狀 (4 * num_processes) 並且將具有分散式值:(0, 1, 2, 3, 0, 2, 4, 6, 0, 3, 6, 9, … ),其中每個切片 np.arange(4) * host_id 將在相應主機的裝置之間進行分割。
類似地
>>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(jax.process_count(), jax.local_device_count()), ['host', 'dev']) >>> pspecs = jax.sharding.PartitionSpec('host') >>> host_id = jax.process_index() >>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs)
將建立相同的分散式值 (0, 1, 2, 3, 0, 2, 4, 6, …),但是每個切片 np.arange(4) * i 將在相應的主機裝置之間複製。
另一方面,如果 pspecs = PartitionSpec(),這表示跨所有軸複製,則此程式碼片段
>>> pspecs = jax.sharding.PartitionSpec() >>> arr = host_local_array_to_global_array(np.arange(4), mesh, pspecs)
將具有形狀 (4,),並且值 (0, 1, 2, 3) 將在所有主機和裝置之間複製。
對於 pspec 指示資料複製的情況,使用非完全相同的 local_inputs 是未定義的行為。
您可以使用此函數轉換到 jax.Array。 將 jax.Array 與 pjit 結合使用具有與將 GDA 與 pjit 結合使用相同的語義,即 pjit 的所有 jax.Array 輸入都應為全域形狀。
如果您目前將主機本機值傳遞給 pjit,則可以使用此函數將您的主機本機值轉換為全域陣列,然後將其傳遞給 pjit。
使用範例。
>>> from jax.experimental import multihost_utils >>> >>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) >>> >>> with mesh: >>> global_out = pjitted_fun(global_inputs) >>> >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs)
請注意,此函數要求全域網格為連續網格,這表示屬於每個主機的裝置應在此網格中形成一個子立方體。 若要將本機資料移動到具有非連續網格的全域陣列,請改用 jax.make_array_from_callback 或 jax.make_array_from_single_device_arrays。
- 參數:
local_inputs (Any) – 主機本機值的 Pytree。
global_mesh (jax.sharding.Mesh) – jax.sharding.Mesh 物件。 網格必須是連續網格,
mesh. (也就是說,所有主機的裝置都必須在此網格中形成一個子立方體)
pspecs (Any) – jax.sharding.PartitionSpec 的 Pytree。
- 回傳值:
全域陣列的 Pytree。