jax.sharding 模組#

類別#

class jax.sharding.Sharding#

描述 jax.Array 如何跨裝置佈局。

property addressable_devices: set[Device]#

在目前進程中可定址的 Sharding 裝置集合。

addressable_devices_indices_map(global_shape)[source]#

從可定址裝置到每個裝置包含的陣列資料切片的映射。

addressable_devices_indices_map 包含 device_indices_map 中適用於可定址裝置的部分。

參數:

global_shape (Shape)

回傳類型:

Mapping[Device, Index | None]

property device_set: set[Device][source]#

Sharding 跨越的裝置集合。

在多控制器 JAX 中,裝置集合是全域的,即包含來自其他進程的不可定址裝置。

devices_indices_map(global_shape)[source]#

回傳從裝置到每個裝置包含的陣列切片的映射。

此映射包含所有全域裝置,即包含來自其他進程的不可定址裝置。

參數:

global_shape (Shape)

回傳類型:

Mapping[Device, Index]

is_equivalent_to(other, ndim)[source]#

如果兩個分片相等,則回傳 True

如果兩個分片將相同的邏輯陣列分片放置在相同的裝置上,則它們是相等的。

例如,如果 NamedShardingPositionalSharding 都將陣列的相同分片放置在相同的裝置上,則它們可能是相等的。

參數:
回傳類型:

bool

property is_fully_addressable: bool[source]#

此分片是否完全可定址?

如果目前進程可以定址 Sharding 中命名的所有裝置,則分片是完全可定址的。在多進程 JAX 中,is_fully_addressable 等同於 “is_local”。

property is_fully_replicated: bool[source]#

此分片是否完全複製?

如果每個裝置都擁有完整資料的完整副本,則分片是完全複製的。

property memory_kind: str | None[source]#

回傳分片的記憶體種類。

property num_devices: int[source]#

分片包含的裝置數量。

shard_shape(global_shape)[source]#

回傳每個裝置上資料的形狀。

此函式回傳的分片形狀是根據 global_shape 和分片的屬性計算得出的。

參數:

global_shape (Shape)

回傳類型:

Shape

with_memory_kind(kind)[source]#

回傳具有指定記憶體種類的新 Sharding 實例。

參數:

kind (str)

回傳類型:

Sharding

class jax.sharding.SingleDeviceSharding#

基底類別:Sharding

將資料放置在單一裝置上的 Sharding

參數:

device – 單一 Device

範例

>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0])
property device_set: set[Device][source]#

Sharding 跨越的裝置集合。

在多控制器 JAX 中,裝置集合是全域的,即包含來自其他進程的不可定址裝置。

devices_indices_map(global_shape)[source]#

回傳從裝置到每個裝置包含的陣列切片的映射。

此映射包含所有全域裝置,即包含來自其他進程的不可定址裝置。

參數:

global_shape (Shape)

回傳類型:

Mapping[Device, Index]

property is_fully_addressable: bool[source]#

此分片是否完全可定址?

如果目前進程可以定址 Sharding 中命名的所有裝置,則分片是完全可定址的。在多進程 JAX 中,is_fully_addressable 等同於 “is_local”。

property is_fully_replicated: bool[source]#

此分片是否完全複製?

如果每個裝置都擁有完整資料的完整副本,則分片是完全複製的。

property memory_kind: str | None[source]#

回傳分片的記憶體種類。

property num_devices: int[source]#

分片包含的裝置數量。

with_memory_kind(kind)[source]#

回傳具有指定記憶體種類的新 Sharding 實例。

參數:

kind (str)

回傳類型:

SingleDeviceSharding

class jax.sharding.NamedSharding#

基底類別:Sharding

NamedSharding 使用命名軸表示分片。

NamedSharding 是一對裝置 MeshPartitionSpec,用於描述如何跨該網格分片陣列。

Mesh 是 JAX 裝置的多維 NumPy 陣列,其中網格的每個軸都有一個名稱,例如 'x''y'

PartitionSpec 是一個元組,其元素可以是 None、網格軸或網格軸元組。每個元素描述輸入維度如何跨零個或多個網格維度進行分割。例如,PartitionSpec('x', 'y') 表示資料的第一個維度跨網格的 x 軸分片,第二個維度跨網格的 y 軸分片。

分散式陣列和自動平行化 (https://jax.dev.org.tw/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) 教學課程有更多詳細資訊和圖表,說明如何使用 MeshPartitionSpec

參數:

範例

>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
property addressable_devices: set[Device][source]#

在目前進程中可定址的 Sharding 裝置集合。

property device_set: set[Device][source]#

Sharding 跨越的裝置集合。

在多控制器 JAX 中,裝置集合是全域的,即包含來自其他進程的不可定址裝置。

屬性 is_fully_addressable: bool[原始碼]#

此分片是否完全可定址?

如果目前進程可以定址 Sharding 中命名的所有裝置,則分片是完全可定址的。在多進程 JAX 中,is_fully_addressable 等同於 “is_local”。

屬性 is_fully_replicated: bool#

此分片是否完全複製?

如果每個裝置都擁有完整資料的完整副本,則分片是完全複製的。

屬性 memory_kind: str | None[原始碼]#

回傳分片的記憶體種類。

屬性 mesh#

(self) -> 物件

屬性 num_devices: int[原始碼]#

分片包含的裝置數量。

屬性 spec#

(self) -> 物件

with_memory_kind(kind)[原始碼]#

回傳具有指定記憶體種類的新 Sharding 實例。

參數:

kind (str)

回傳類型:

NamedSharding

類別 jax.sharding.PositionalSharding(devices, *, memory_kind=None)[原始碼]#

基底類別:Sharding

參數:
  • devices (序列[xc.Device] | np.ndarray)

  • memory_kind (字串 | None)

屬性 device_set: set[xc.Device]#

Sharding 跨越的裝置集合。

在多控制器 JAX 中,裝置集合是全域的,即包含來自其他進程的不可定址裝置。

屬性 is_fully_addressable: bool#

此分片是否完全可定址?

如果目前進程可以定址 Sharding 中命名的所有裝置,則分片是完全可定址的。在多進程 JAX 中,is_fully_addressable 等同於 “is_local”。

屬性 is_fully_replicated: bool#

此分片是否完全複製?

如果每個裝置都擁有完整資料的完整副本,則分片是完全複製的。

屬性 memory_kind: str | None[原始碼]#

回傳分片的記憶體種類。

屬性 num_devices: int[原始碼]#

分片包含的裝置數量。

with_memory_kind(kind)[原始碼]#

回傳具有指定記憶體種類的新 Sharding 實例。

參數:

kind (str)

回傳類型:

PositionalSharding

類別 jax.sharding.PmapSharding#

基底類別:Sharding

描述由 jax.pmap() 使用的分片。

classmethod default(shape, sharded_dim=0, devices=None)[原始碼]#

建立一個 PmapSharding,其符合 jax.pmap() 使用的預設放置方式。

參數:
  • shape (Shape) – 輸入陣列的形狀。

  • sharded_dim (整數 | None) – 輸入陣列被分片的維度。預設為 0。

  • devices (序列[xc.Device] | None | None) – 可選的裝置序列以供使用。如果省略,則為隱含的

  • used (pmap 使用的裝置順序為) – jax.local_devices()

  • of (其順序為) – jax.local_devices()

回傳類型:

PmapSharding

屬性 device_set: set[Device]#

Sharding 跨越的裝置集合。

在多控制器 JAX 中,裝置集合是全域的,即包含來自其他進程的不可定址裝置。

屬性 devices#

(self) -> ndarray

devices_indices_map(global_shape)[原始碼]#

回傳從裝置到每個裝置包含的陣列切片的映射。

此映射包含所有全域裝置,即包含來自其他進程的不可定址裝置。

參數:

global_shape (Shape)

回傳類型:

Mapping[Device, Index]

is_equivalent_to(other, ndim)[原始碼]#

如果兩個分片相等,則回傳 True

如果兩個分片將相同的邏輯陣列分片放置在相同的裝置上,則它們是相等的。

例如,如果 NamedShardingPositionalSharding 都將陣列的相同分片放置在相同的裝置上,則它們可能是相等的。

參數:
回傳類型:

bool

屬性 is_fully_addressable: bool#

此分片是否完全可定址?

如果目前進程可以定址 Sharding 中命名的所有裝置,則分片是完全可定址的。在多進程 JAX 中,is_fully_addressable 等同於 “is_local”。

屬性 is_fully_replicated: bool#

此分片是否完全複製?

如果每個裝置都擁有完整資料的完整副本,則分片是完全複製的。

屬性 memory_kind: str | None[原始碼]#

回傳分片的記憶體種類。

屬性 num_devices: int[原始碼]#

分片包含的裝置數量。

shard_shape(global_shape)[原始碼]#

回傳每個裝置上資料的形狀。

此函式回傳的分片形狀是根據 global_shape 和分片的屬性計算得出的。

參數:

global_shape (Shape)

回傳類型:

Shape

屬性 sharding_spec#

(self) -> jax::ShardingSpec

with_memory_kind(kind)[原始碼]#

回傳具有指定記憶體種類的新 Sharding 實例。

參數:

kind (str)

類別 jax.sharding.GSPMDSharding#

基底類別:Sharding

屬性 device_set: set[Device]#

Sharding 跨越的裝置集合。

在多控制器 JAX 中,裝置集合是全域的,即包含來自其他進程的不可定址裝置。

屬性 is_fully_addressable: bool#

此分片是否完全可定址?

如果目前進程可以定址 Sharding 中命名的所有裝置,則分片是完全可定址的。在多進程 JAX 中,is_fully_addressable 等同於 “is_local”。

屬性 is_fully_replicated: bool#

此分片是否完全複製?

如果每個裝置都擁有完整資料的完整副本,則分片是完全複製的。

屬性 memory_kind: str | None[原始碼]#

回傳分片的記憶體種類。

屬性 num_devices: int[原始碼]#

分片包含的裝置數量。

with_memory_kind(kind)[原始碼]#

回傳具有指定記憶體種類的新 Sharding 實例。

參數:

kind (str)

回傳類型:

GSPMDSharding

類別 jax.sharding.PartitionSpec(*partitions)[原始碼]#

元組,描述如何在裝置網格上分割陣列。

每個元素可以是 None、字串或字串元組。請參閱 jax.sharding.NamedSharding 的文件以了解更多詳細資訊。

此類別的存在是為了讓 JAX 的 pytree 工具程式可以區分分片規格與應視為 pytree 的元組。

類別 jax.sharding.Mesh(devices, axis_names, *, axis_types=None)[原始碼]#

宣告在此管理器範圍內可用的硬體資源。

特別是,所有 axis_names 都成為受管理區塊內的有效資源名稱,並且可以用於例如 jax.experimental.pjit.pjit()in_axis_resources 引數中。另請參閱 JAX 的多進程程式設計模型 (https://jax.dev.org.tw/en/latest/multi_process.html) 和分散式陣列與自動平行化教學 (https://jax.dev.org.tw/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)

如果您正在多個執行緒中編譯,請確保 with Mesh 上下文管理器位於執行緒將執行的函數內。

參數:
  • devices (np.ndarray) – 包含 JAX 裝置物件的 NumPy ndarray 物件 (例如從 jax.devices() 取得)。

  • axis_names (tuple[MeshAxisName, ...]) – 要指派給裝置引數維度的資源軸名稱序列。其長度應符合 devices 的秩。

  • axis_types (MeshAxisType)

範例

>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> inp = np.arange(16).reshape((8, 2))
>>> devices = np.array(jax.devices()).reshape(4, 2)
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> # Use the mesh object directly as a context manager.
>>> with global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Initialize the Mesh and use the mesh as the context manager.
>>> with Mesh(devices, ('x', 'y')) as global_mesh:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # Also you can use it as `with ... as ...`.
>>> global_mesh = Mesh(devices, ('x', 'y'))
>>> with global_mesh as m:
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)
>>> # You can also use it as `with Mesh(...)`.
>>> with Mesh(devices, ('x', 'y')):
...   out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(inp)