jax.sharding
模組#
類別#
- class jax.sharding.Sharding#
描述
jax.Array
如何跨裝置佈局。- addressable_devices_indices_map(global_shape)[source]#
從可定址裝置到每個裝置包含的陣列資料切片的映射。
addressable_devices_indices_map
包含device_indices_map
中適用於可定址裝置的部分。- 參數:
global_shape (Shape)
- 回傳類型:
Mapping[Device, Index | None]
- property device_set: set[Device][source]#
此
Sharding
跨越的裝置集合。在多控制器 JAX 中,裝置集合是全域的,即包含來自其他進程的不可定址裝置。
- devices_indices_map(global_shape)[source]#
回傳從裝置到每個裝置包含的陣列切片的映射。
此映射包含所有全域裝置,即包含來自其他進程的不可定址裝置。
- 參數:
global_shape (Shape)
- 回傳類型:
Mapping[Device, Index]
- is_equivalent_to(other, ndim)[source]#
如果兩個分片相等,則回傳
True
。如果兩個分片將相同的邏輯陣列分片放置在相同的裝置上,則它們是相等的。
例如,如果
NamedSharding
和PositionalSharding
都將陣列的相同分片放置在相同的裝置上,則它們可能是相等的。
- property is_fully_addressable: bool[source]#
此分片是否完全可定址?
如果目前進程可以定址
Sharding
中命名的所有裝置,則分片是完全可定址的。在多進程 JAX 中,is_fully_addressable
等同於 “is_local”。
- class jax.sharding.SingleDeviceSharding#
基底類別:
Sharding
將資料放置在單一裝置上的
Sharding
。- 參數:
device – 單一
Device
。
範例
>>> single_device_sharding = jax.sharding.SingleDeviceSharding( ... jax.devices()[0])
- property device_set: set[Device][source]#
此
Sharding
跨越的裝置集合。在多控制器 JAX 中,裝置集合是全域的,即包含來自其他進程的不可定址裝置。
- devices_indices_map(global_shape)[source]#
回傳從裝置到每個裝置包含的陣列切片的映射。
此映射包含所有全域裝置,即包含來自其他進程的不可定址裝置。
- 參數:
global_shape (Shape)
- 回傳類型:
Mapping[Device, Index]
- property is_fully_addressable: bool[source]#
此分片是否完全可定址?
如果目前進程可以定址
Sharding
中命名的所有裝置,則分片是完全可定址的。在多進程 JAX 中,is_fully_addressable
等同於 “is_local”。
- class jax.sharding.NamedSharding#
基底類別:
Sharding
NamedSharding
使用命名軸表示分片。NamedSharding
是一對裝置Mesh
和PartitionSpec
,用於描述如何跨該網格分片陣列。Mesh
是 JAX 裝置的多維 NumPy 陣列,其中網格的每個軸都有一個名稱,例如'x'
或'y'
。PartitionSpec
是一個元組,其元素可以是None
、網格軸或網格軸元組。每個元素描述輸入維度如何跨零個或多個網格維度進行分割。例如,PartitionSpec('x', 'y')
表示資料的第一個維度跨網格的x
軸分片,第二個維度跨網格的y
軸分片。分散式陣列和自動平行化 (https://jax.dev.org.tw/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) 教學課程有更多詳細資訊和圖表,說明如何使用
Mesh
和PartitionSpec
。- 參數:
mesh –
jax.sharding.Mesh
物件。spec –
jax.sharding.PartitionSpec
物件。
範例
>>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) >>> spec = P('x', 'y') >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
- property device_set: set[Device][source]#
此
Sharding
跨越的裝置集合。在多控制器 JAX 中,裝置集合是全域的,即包含來自其他進程的不可定址裝置。
- 屬性 is_fully_addressable: bool[原始碼]#
此分片是否完全可定址?
如果目前進程可以定址
Sharding
中命名的所有裝置,則分片是完全可定址的。在多進程 JAX 中,is_fully_addressable
等同於 “is_local”。
- 屬性 mesh#
(self) -> 物件
- 屬性 spec#
(self) -> 物件
- 類別 jax.sharding.PositionalSharding(devices, *, memory_kind=None)[原始碼]#
基底類別:
Sharding
- 參數:
devices (序列[xc.Device] | np.ndarray)
memory_kind (字串 | None)
- 屬性 is_fully_addressable: bool#
此分片是否完全可定址?
如果目前進程可以定址
Sharding
中命名的所有裝置,則分片是完全可定址的。在多進程 JAX 中,is_fully_addressable
等同於 “is_local”。
- 類別 jax.sharding.PmapSharding#
基底類別:
Sharding
描述由
jax.pmap()
使用的分片。- classmethod default(shape, sharded_dim=0, devices=None)[原始碼]#
建立一個
PmapSharding
,其符合jax.pmap()
使用的預設放置方式。- 參數:
shape (Shape) – 輸入陣列的形狀。
sharded_dim (整數 | None) – 輸入陣列被分片的維度。預設為 0。
devices (序列[xc.Device] | None | None) – 可選的裝置序列以供使用。如果省略,則為隱含的
used (pmap 使用的裝置順序為) –
jax.local_devices()
。of (其順序為) –
jax.local_devices()
。
- 回傳類型:
- 屬性 devices#
(self) -> ndarray
- devices_indices_map(global_shape)[原始碼]#
回傳從裝置到每個裝置包含的陣列切片的映射。
此映射包含所有全域裝置,即包含來自其他進程的不可定址裝置。
- 參數:
global_shape (Shape)
- 回傳類型:
Mapping[Device, Index]
- is_equivalent_to(other, ndim)[原始碼]#
如果兩個分片相等,則回傳
True
。如果兩個分片將相同的邏輯陣列分片放置在相同的裝置上,則它們是相等的。
例如,如果
NamedSharding
和PositionalSharding
都將陣列的相同分片放置在相同的裝置上,則它們可能是相等的。- 參數:
self (PmapSharding)
other (PmapSharding)
ndim (int)
- 回傳類型:
- 屬性 is_fully_addressable: bool#
此分片是否完全可定址?
如果目前進程可以定址
Sharding
中命名的所有裝置,則分片是完全可定址的。在多進程 JAX 中,is_fully_addressable
等同於 “is_local”。
- shard_shape(global_shape)[原始碼]#
回傳每個裝置上資料的形狀。
此函式回傳的分片形狀是根據
global_shape
和分片的屬性計算得出的。- 參數:
global_shape (Shape)
- 回傳類型:
Shape
- 屬性 sharding_spec#
(self) -> jax::ShardingSpec
- 類別 jax.sharding.GSPMDSharding#
基底類別:
Sharding
- 屬性 is_fully_addressable: bool#
此分片是否完全可定址?
如果目前進程可以定址
Sharding
中命名的所有裝置,則分片是完全可定址的。在多進程 JAX 中,is_fully_addressable
等同於 “is_local”。
- 類別 jax.sharding.PartitionSpec(*partitions)[原始碼]#
元組,描述如何在裝置網格上分割陣列。
每個元素可以是
None
、字串或字串元組。請參閱jax.sharding.NamedSharding
的文件以了解更多詳細資訊。此類別的存在是為了讓 JAX 的 pytree 工具程式可以區分分片規格與應視為 pytree 的元組。
- 類別 jax.sharding.Mesh(devices, axis_names, *, axis_types=None)[原始碼]#
宣告在此管理器範圍內可用的硬體資源。
特別是,所有
axis_names
都成為受管理區塊內的有效資源名稱,並且可以用於例如jax.experimental.pjit.pjit()
的in_axis_resources
引數中。另請參閱 JAX 的多進程程式設計模型 (https://jax.dev.org.tw/en/latest/multi_process.html) 和分散式陣列與自動平行化教學 (https://jax.dev.org.tw/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)如果您正在多個執行緒中編譯,請確保
with Mesh
上下文管理器位於執行緒將執行的函數內。- 參數:
devices (np.ndarray) – 包含 JAX 裝置物件的 NumPy ndarray 物件 (例如從
jax.devices()
取得)。axis_names (tuple[MeshAxisName, ...]) – 要指派給裝置引數維度的資源軸名稱序列。其長度應符合
devices
的秩。axis_types (MeshAxisType)
範例
>>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> import numpy as np ... >>> inp = np.arange(16).reshape((8, 2)) >>> devices = np.array(jax.devices()).reshape(4, 2) ... >>> # Declare a 2D mesh with axes `x` and `y`. >>> global_mesh = Mesh(devices, ('x', 'y')) >>> # Use the mesh object directly as a context manager. >>> with global_mesh: ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager. >>> with Mesh(devices, ('x', 'y')) as global_mesh: ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Also you can use it as `with ... as ...`. >>> global_mesh = Mesh(devices, ('x', 'y')) >>> with global_mesh as m: ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`. >>> with Mesh(devices, ('x', 'y')): ... out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)