jax.make_array_from_single_device_arrays#
- jax.make_array_from_single_device_arrays(shape, sharding, arrays)[原始碼]#
- 從單一裝置陣列的序列傳回
jax.Array
。 輸入
sharding
的網格中的每個裝置都必須在arrays
中有一個陣列。
- 參數:
shape (Shape) – 輸出
jax.Array
的形狀。這傳達了已經包含在sharding
和arrays
中的資訊,並作為雙重檢查。sharding (Sharding) – Sharding:全域 Sharding 實例,描述輸出 jax.Array 如何在裝置之間佈局。
arrays (Sequence[basearray.Array]) – 每個都是單一裝置可定址的
jax.Array
序列。len(arrays)
必須等於len(sharding.addressable_devices)
,且每個陣列的形狀必須相同。對於多進程程式碼,每個進程將使用不同的arrays
引數呼叫,該引數對應於該進程的資料。這些陣列通常透過jax.device_put
建立。
- 傳回值:
- 全域
jax.Array
,以sharding
分片,形狀等於shape
,且每個裝置 內容與
arrays
相符。
- 全域
- 傳回類型:
ArrayImpl
範例
>>> import math >>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> import numpy as np ... >>> mesh_rows = 2 >>> mesh_cols = jax.device_count() // 2 ... >>> global_shape = (8, 8) >>> mesh = Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y')) >>> sharding = jax.sharding.NamedSharding(mesh, P('x', 'y')) >>> inp_data = np.arange(math.prod(global_shape)).reshape(global_shape) ... >>> arrays = [ ... jax.device_put(inp_data[index], d) ... for d, index in sharding.addressable_devices_indices_map(global_shape).items()] ... >>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays) >>> assert arr.shape == (8,8) # arr.shape is (8,8) regardless of jax.device_count()
對於您有本機陣列並想要將其轉換為全域 jax.Array 的情況,請使用
jax.make_array_from_process_local_data
。- 從單一裝置陣列的序列傳回