jax.make_array_from_single_device_arrays#

jax.make_array_from_single_device_arrays(shape, sharding, arrays)[原始碼]#
從單一裝置陣列的序列傳回 jax.Array

輸入 sharding 的網格中的每個裝置都必須在 arrays 中有一個陣列。

參數:
  • shape (Shape) – 輸出 jax.Array 的形狀。這傳達了已經包含在 shardingarrays 中的資訊,並作為雙重檢查。

  • sharding (Sharding) – Sharding:全域 Sharding 實例,描述輸出 jax.Array 如何在裝置之間佈局。

  • arrays (Sequence[basearray.Array]) – 每個都是單一裝置可定址的 jax.Array 序列。len(arrays) 必須等於 len(sharding.addressable_devices),且每個陣列的形狀必須相同。對於多進程程式碼,每個進程將使用不同的 arrays 引數呼叫,該引數對應於該進程的資料。這些陣列通常透過 jax.device_put 建立。

傳回值:

全域 jax.Array,以 sharding 分片,形狀等於 shape,且每個裝置

內容與 arrays 相符。

傳回類型:

ArrayImpl

範例

>>> import math
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> mesh_rows = 2
>>> mesh_cols =  jax.device_count() // 2
...
>>> global_shape = (8, 8)
>>> mesh = Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
>>> inp_data = np.arange(math.prod(global_shape)).reshape(global_shape)
...
>>> arrays = [
...    jax.device_put(inp_data[index], d)
...        for d, index in sharding.addressable_devices_indices_map(global_shape).items()]
...
>>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
>>> assert arr.shape == (8,8) # arr.shape is (8,8) regardless of jax.device_count()

對於您有本機陣列並想要將其轉換為全域 jax.Array 的情況,請使用 jax.make_array_from_process_local_data