jax.Array 遷移#

yashkatariya@

摘要#

JAX 從 0.4.1 版本開始,將其預設陣列實作切換為新的 jax.Array。本指南說明了此舉背後的理由、可能對您的程式碼造成的影響,以及如何(暫時)切換回舊行為。

發生了什麼事?#

jax.Array 是一種統一的陣列類型,它包含 JAX 中的 DeviceArrayShardedDeviceArrayGlobalDeviceArray 類型。jax.Array 類型有助於使平行處理成為 JAX 的核心功能,簡化和統一 JAX 內部機制,並讓我們能夠統一 jit 和 pjit。如果您的程式碼沒有提及 DeviceArrayShardedDeviceArrayGlobalDeviceArray,則無需進行任何變更。但是,依賴這些個別類別詳細資訊的程式碼可能需要進行調整,才能與統一的 jax.Array 搭配使用

遷移完成後,jax.Array 將成為 JAX 中唯一的陣列類型。

本文檔說明如何將現有的程式碼庫遷移到 jax.Array。有關使用 jax.Array 和 JAX 平行處理 API 的更多資訊,請參閱分散式陣列和自動平行化教學課程。

如何啟用 jax.Array?#

您可以透過以下方式啟用 jax.Array

  • 將 shell 環境變數 JAX_ARRAY 設定為類似 true 的值(例如,1);

  • 如果您的程式碼使用 absl 解析標誌,則將布林標誌 jax_array 設定為類似 true 的值;

  • 在您的主檔案頂部使用此陳述式

    import jax
    jax.config.update('jax_array', True)
    

我如何知道 jax.Array 是否破壞了我的程式碼?#

判斷 jax.Array 是否是任何問題的根源的最簡單方法是停用 jax.Array,並查看問題是否消失。

我現在可以如何停用 jax.Array?#

2023 年 3 月 15 日 之前,可以透過以下方式停用 jax.Array

  • 將 shell 環境變數 JAX_ARRAY 設定為類似 false 的值(例如,0);

  • 如果您的程式碼使用 absl 解析標誌,則將布林標誌 jax_array 設定為類似 false 的值;

  • 在您的主檔案頂部使用此陳述式

    import jax
    jax.config.update('jax_array', False)
    

為何建立 jax.Array?#

目前 JAX 有三種類型:DeviceArrayShardedDeviceArrayGlobalDeviceArrayjax.Array 合併了這三種類型,並清理了 JAX 的內部機制,同時新增了新的平行處理功能。

我們還引入了新的 Sharding 抽象概念,用於描述邏輯陣列如何在一個或多個裝置(例如 TPU 或 GPU)上進行物理分片。此變更還升級、簡化和合併了 pjit 的平行處理功能到 jit 中。使用 jit 修飾的函式將能夠對分片陣列進行操作,而無需將資料複製到單一裝置上。

您透過 jax.Array 獲得的功能

  • C++ pjit 分派路徑

  • 逐運算元平行處理(即使陣列分散在多個主機的多個裝置上)

  • 使用 pjit/jit 更簡單的批次資料平行處理。

  • 建立不一定由網格和分割規格組成的 Sharding 的方法。如果您願意,可以充分利用 OpSharding 的彈性,或任何其他您想要的 Sharding。

  • 以及更多

範例

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
x = jnp.arange(8)

# Let's say there are 8 devices in jax.devices()
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))

sharded_x = jax.device_put(x, sharding)

# `matmul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate over a
# sharded array without copying data to a single device.
matmul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)

# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)

# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)

當 jax.Array 開啟時,可能會出現哪些問題?#

名為 jax.Array 的新公共類型#

所有 isinstance(..., jnp.DeviceArray)isinstance(.., jax.xla.DeviceArray) 和其他 DeviceArray 的變體都應切換為使用 isinstance(..., jax.Array)

由於 jax.Array 可以表示 DA、SDA 和 GDA,因此您可以透過以下方式在 jax.Array 中區分這 3 種類型

  • x.is_fully_addressable and len(x.sharding.device_set) == 1 – 這表示 jax.Array 類似於 DA

  • x.is_fully_addressable and (len(x.sharding.device_set) > 1 – 這表示 jax.Array 類似於 SDA

  • not x.is_fully_addressable – 這表示 jax.Array 類似於 GDA 且跨越多個進程

對於 ShardedDeviceArray,您可以將 isinstance(..., pxla.ShardedDeviceArray) 移至 isinstance(..., jax.Array) and x.is_fully_addressable and len(x.sharding.device_set) > 1

一般而言,無法區分單一裝置上的 ShardedDeviceArray 與任何其他種類的單一裝置陣列。

GDA 的 API 名稱變更#

GDA 的 local_shardslocal_data 已被棄用。

請使用與 jax.ArrayGDA 相容的 addressable_shardsaddressable_data

建立 jax.Array#

jax_array 標誌為 True 時,所有 JAX 函式都將輸出 jax.Array。如果您使用 GlobalDeviceArray.from_callbackmake_sharded_device_arraymake_device_array 函式來明確建立各自的 JAX 資料類型,則需要將它們切換為使用 jax.make_array_from_callback()jax.make_array_from_single_device_arrays()

對於 GDA

GlobalDeviceArray.from_callback(shape, mesh, pspec, callback) 可以變成 jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback),以進行 1 對 1 切換。

如果您使用原始 GDA 建構函式來建立 GDA,請執行以下操作

GlobalDeviceArray(shape, mesh, pspec, buffers) 可以變成 jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers)

對於 SDA

make_sharded_device_array(aval, sharding_spec, device_buffers, indices) 可以變成 jax.make_array_from_single_device_arrays(shape, sharding, device_buffers)

要決定 sharding 應該是什麼,這取決於您為何建立 SDA

如果是為了作為 pmap 的輸入而建立,則 sharding 可以是:jax.sharding.PmapSharding(devices, sharding_spec)

如果是為了作為 pjit 的輸入而建立,則 sharding 可以是 jax.sharding.NamedSharding(mesh, pspec)

切換到 jax.Array 以處理主機本機輸入後,pjit 的重大變更#

如果您完全使用 GDA 引數到 pjit,則可以跳過本節! 🎉

啟用 jax.Array 後,pjit 的所有輸入都必須是全域形狀。這是與先前行為的重大變更,先前行為中 pjit 會將進程本機引數串連成全域值;此串連不再發生。

我們為何進行此重大變更?現在每個陣列都明確說明其本機分片如何融入全域整體,而不是讓它隱含。更明確的表示法也解鎖了額外的彈性,例如將非連續網格與 pjit 搭配使用,這可以提高某些 TPU 模型上的效率。

當啟用 jax.Array 時,執行 多進程 pjit 計算 並傳遞主機本機輸入可能會導致類似於此的錯誤

範例

網格 = {'x': 2, 'y': 2, 'z': 2} 和主機本機輸入形狀 == (4,) 和 pspec = P(('x', 'y', 'z'))

由於 pjit 不使用 jax.Array 將主機本機形狀提升為全域形狀,因此您會收到以下錯誤

注意:僅當您的主機本機形狀小於網格的形狀時,您才會看到此錯誤。

ValueError: One of pjit arguments was given the sharding of
NamedSharding(mesh={'x': 2, 'y': 2, 'chips': 2}, partition_spec=PartitionSpec(('x', 'y', 'chips'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4

此錯誤是有道理的,因為當維度 0 上的值為 4 時,您無法將維度 0 分片 8 種方式。

如果您仍然將主機本機輸入傳遞給 pjit,您該如何遷移?我們正在提供過渡性 API 來協助您遷移

注意:如果您在單一進程上執行您的 pjitted 計算,則不需要這些公用程式。

from jax.experimental import multihost_utils

global_inps = multihost_utils.host_local_array_to_global_array(
    local_inputs, mesh, in_pspecs)

global_outputs = pjit(f, in_shardings=in_pspecs,
                      out_shardings=out_pspecs)(global_inps)

local_outs = multihost_utils.global_array_to_host_local_array(
    global_outputs, mesh, out_pspecs)

host_local_array_to_global_array 是一種類型轉換,它會查看僅具有本機分片的值,並將其本機形狀變更為 pjit 在變更之前假設傳遞該值時的形狀。

仍然支援傳遞完全複製的輸入,即每個進程上的形狀相同,並以 P(None) 作為 in_axis_resources。在這種情況下,您不必使用 host_local_array_to_global_array,因為形狀已經是全域的。

key = jax.random.PRNGKey(1)

# As you can see, using host_local_array_to_global_array is not required since in_axis_resources says
# that the input is fully replicated via P(None)
pjit(f, in_shardings=None, out_shardings=None)(key)

# Mixing inputs
global_inp = multihost_utils.host_local_array_to_global_array(
    local_inp, mesh, P('data'))
global_out = pjit(f, in_shardings=(P(None), P('data')),
                  out_shardings=...)(key, global_inp)

FROM_GDA 和 jax.Array#

如果您在 pjitin_axis_resources 引數中使用 FROM_GDA,則使用 jax.Array 時,無需將任何內容傳遞到 in_axis_resources,因為 jax.Array 將遵循計算遵循 sharding 語義。

例如

pjit(f, in_shardings=FROM_GDA, out_shardings=...) can be replaced by pjit(f, out_shardings=...)

如果您的 PartitionSpecs 與用於 numpy 陣列等輸入的 FROM_GDA 混合在一起,則使用 host_local_array_to_global_array 將它們轉換為 jax.Array

例如

如果您有這個

pjitted_f = pjit(
    f, in_shardings=(FROM_GDA, P('x'), FROM_GDA, P(None)),
    out_shardings=...)
pjitted_f(gda1, np_array1, gda2, np_array2)

那麼您可以將其替換為


pjitted_f = pjit(f, out_shardings=...)

array2, array3 = multihost_utils.host_local_array_to_global_array(
    (np_array1, np_array2), mesh, (P('x'), P(None)))

pjitted_f(array1, array2, array3, array4)

live_buffers 已替換為 live_arrays#

jax Device 上的 live_buffers 屬性已被棄用。請改用與 jax.Array 相容的 jax.live_arrays()

處理批次等主機本機輸入到 pjit#

如果您在多進程環境中將主機本機輸入傳遞到 pjit,請使用 multihost_utils.host_local_array_to_global_array 將批次轉換為全域 jax.Array,然後將其傳遞到 pjit

此類主機本機輸入最常見的範例是輸入資料批次

這適用於任何主機本機輸入(而不僅僅是輸入資料批次)。

from jax.experimental import multihost_utils

batch = multihost_utils.host_local_array_to_global_array(
    batch, mesh, batch_partition_spec)

有關此變更和更多範例的詳細資訊,請參閱上面的 pjit 章節。

RecursionError:遞迴呼叫 jit#

當您的程式碼的某些部分停用了 jax.Array,然後您僅針對其他部分啟用它時,就會發生這種情況。例如,如果您使用某些停用了 jax.Array 的第三方程式碼,並且您從該程式庫取得 DeviceArray,然後您在您的程式庫中啟用 jax.Array,並將該 DeviceArray 傳遞給 JAX 函式,則會導致 RecursionError。

當預設啟用 jax.Array,以便所有程式庫都傳回 jax.Array,除非它們明確停用它時,此錯誤應該會消失。