jax.experimental.multihost_utils.host_local_array_to_global_array#

jax.experimental.multihost_utils.host_local_array_to_global_array(local_inputs, global_mesh, pspecs)[原始碼]#

將主機本機值轉換為全域分片的 jax.Array。

此函數接受主機本機資料(在不同主機上可能不同),並使用此資料填充全域陣列,其中每個主機上的每個裝置都根據 global_mesh/pspects 定義的分片獲得適當的資料切片。

例如

>>> global_mesh = jax.sharding.Mesh(jax.devices(), 'x')
>>> pspecs = jax.sharding.PartitionSpec('x')
>>> host_id = jax.process_index()
>>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs)  # NB: assumes jax.local_device_count() divides 4.   

產生的陣列將具有形狀 (4 * num_processes) 並且將具有分散式值:(0, 1, 2, 3, 0, 2, 4, 6, 0, 3, 6, 9, … ),其中每個切片 np.arange(4) * host_id 將在相應主機的裝置之間進行分割。

類似地

>>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(jax.process_count(), jax.local_device_count()), ['host', 'dev'])
>>> pspecs = jax.sharding.PartitionSpec('host')
>>> host_id = jax.process_index()
>>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs)  

將建立相同的分散式值 (0, 1, 2, 3, 0, 2, 4, 6, …),但是每個切片 np.arange(4) * i 將在相應的主機裝置之間複製

另一方面,如果 pspecs = PartitionSpec(),這表示跨所有軸複製,則此程式碼片段

>>> pspecs = jax.sharding.PartitionSpec()
>>> arr = host_local_array_to_global_array(np.arange(4), mesh, pspecs)  

將具有形狀 (4,),並且值 (0, 1, 2, 3) 將在所有主機和裝置之間複製。

對於 pspec 指示資料複製的情況,使用非完全相同的 local_inputs 是未定義的行為。

您可以使用此函數轉換到 jax.Array。 將 jax.Array 與 pjit 結合使用具有與將 GDA 與 pjit 結合使用相同的語義,即 pjit 的所有 jax.Array 輸入都應為全域形狀。

如果您目前將主機本機值傳遞給 pjit,則可以使用此函數將您的主機本機值轉換為全域陣列,然後將其傳遞給 pjit。

使用範例。

>>> from jax.experimental import multihost_utils 
>>>
>>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) 
>>>
>>> with mesh: 
>>>   global_out = pjitted_fun(global_inputs) 
>>>
>>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) 

請注意,此函數要求全域網格為連續網格,這表示屬於每個主機的裝置應在此網格中形成一個子立方體。 若要將本機資料移動到具有非連續網格的全域陣列,請改用 jax.make_array_from_callback 或 jax.make_array_from_single_device_arrays。

參數:
  • local_inputs (Any) – 主機本機值的 Pytree。

  • global_mesh (jax.sharding.Mesh) – jax.sharding.Mesh 物件。 網格必須是連續網格,

  • mesh. (也就是說,所有主機的裝置都必須在此網格中形成一個子立方體)

  • pspecs (Any) – jax.sharding.PartitionSpec 的 Pytree。

回傳值:

全域陣列的 Pytree。