jax.Array 遷移#
yashkatariya@
摘要#
JAX 從 0.4.1 版本開始,將其預設陣列實作切換為新的 jax.Array
。本指南說明了此舉背後的理由、可能對您的程式碼造成的影響,以及如何(暫時)切換回舊行為。
發生了什麼事?#
jax.Array
是一種統一的陣列類型,它包含 JAX 中的 DeviceArray
、ShardedDeviceArray
和 GlobalDeviceArray
類型。jax.Array
類型有助於使平行處理成為 JAX 的核心功能,簡化和統一 JAX 內部機制,並讓我們能夠統一 jit 和 pjit。如果您的程式碼沒有提及 DeviceArray
與 ShardedDeviceArray
與 GlobalDeviceArray
,則無需進行任何變更。但是,依賴這些個別類別詳細資訊的程式碼可能需要進行調整,才能與統一的 jax.Array 搭配使用
遷移完成後,jax.Array
將成為 JAX 中唯一的陣列類型。
本文檔說明如何將現有的程式碼庫遷移到 jax.Array
。有關使用 jax.Array
和 JAX 平行處理 API 的更多資訊,請參閱分散式陣列和自動平行化教學課程。
如何啟用 jax.Array?#
您可以透過以下方式啟用 jax.Array
將 shell 環境變數
JAX_ARRAY
設定為類似 true 的值(例如,1
);如果您的程式碼使用 absl 解析標誌,則將布林標誌
jax_array
設定為類似 true 的值;在您的主檔案頂部使用此陳述式
import jax jax.config.update('jax_array', True)
我如何知道 jax.Array 是否破壞了我的程式碼?#
判斷 jax.Array
是否是任何問題的根源的最簡單方法是停用 jax.Array
,並查看問題是否消失。
我現在可以如何停用 jax.Array?#
在 2023 年 3 月 15 日 之前,可以透過以下方式停用 jax.Array
將 shell 環境變數
JAX_ARRAY
設定為類似 false 的值(例如,0
);如果您的程式碼使用 absl 解析標誌,則將布林標誌
jax_array
設定為類似 false 的值;在您的主檔案頂部使用此陳述式
import jax jax.config.update('jax_array', False)
為何建立 jax.Array?#
目前 JAX 有三種類型:DeviceArray
、ShardedDeviceArray
和 GlobalDeviceArray
。jax.Array
合併了這三種類型,並清理了 JAX 的內部機制,同時新增了新的平行處理功能。
我們還引入了新的 Sharding
抽象概念,用於描述邏輯陣列如何在一個或多個裝置(例如 TPU 或 GPU)上進行物理分片。此變更還升級、簡化和合併了 pjit
的平行處理功能到 jit
中。使用 jit
修飾的函式將能夠對分片陣列進行操作,而無需將資料複製到單一裝置上。
您透過 jax.Array
獲得的功能
C++
pjit
分派路徑逐運算元平行處理(即使陣列分散在多個主機的多個裝置上)
使用
pjit
/jit
更簡單的批次資料平行處理。建立不一定由網格和分割規格組成的
Sharding
的方法。如果您願意,可以充分利用 OpSharding 的彈性,或任何其他您想要的 Sharding。以及更多
範例
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
x = jnp.arange(8)
# Let's say there are 8 devices in jax.devices()
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
sharded_x = jax.device_put(x, sharding)
# `matmul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate over a
# sharded array without copying data to a single device.
matmul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)
# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)
# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)
當 jax.Array 開啟時,可能會出現哪些問題?#
名為 jax.Array 的新公共類型#
所有 isinstance(..., jnp.DeviceArray)
或 isinstance(.., jax.xla.DeviceArray)
和其他 DeviceArray
的變體都應切換為使用 isinstance(..., jax.Array)
。
由於 jax.Array
可以表示 DA、SDA 和 GDA,因此您可以透過以下方式在 jax.Array
中區分這 3 種類型
x.is_fully_addressable and len(x.sharding.device_set) == 1
– 這表示jax.Array
類似於 DAx.is_fully_addressable and (len(x.sharding.device_set) > 1
– 這表示jax.Array
類似於 SDAnot x.is_fully_addressable
– 這表示jax.Array
類似於 GDA 且跨越多個進程
對於 ShardedDeviceArray
,您可以將 isinstance(..., pxla.ShardedDeviceArray)
移至 isinstance(..., jax.Array) and x.is_fully_addressable and len(x.sharding.device_set) > 1
。
一般而言,無法區分單一裝置上的 ShardedDeviceArray
與任何其他種類的單一裝置陣列。
GDA 的 API 名稱變更#
GDA 的 local_shards
和 local_data
已被棄用。
請使用與 jax.Array
和 GDA
相容的 addressable_shards
和 addressable_data
。
建立 jax.Array#
當 jax_array
標誌為 True 時,所有 JAX 函式都將輸出 jax.Array
。如果您使用 GlobalDeviceArray.from_callback
或 make_sharded_device_array
或 make_device_array
函式來明確建立各自的 JAX 資料類型,則需要將它們切換為使用 jax.make_array_from_callback()
或 jax.make_array_from_single_device_arrays()
。
對於 GDA
GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)
可以變成 jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)
,以進行 1 對 1 切換。
如果您使用原始 GDA 建構函式來建立 GDA,請執行以下操作
GlobalDeviceArray(shape, mesh, pspec, buffers)
可以變成 jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers)
對於 SDA
make_sharded_device_array(aval, sharding_spec, device_buffers, indices)
可以變成 jax.make_array_from_single_device_arrays(shape, sharding, device_buffers)
。
要決定 sharding 應該是什麼,這取決於您為何建立 SDA
如果是為了作為 pmap
的輸入而建立,則 sharding 可以是:jax.sharding.PmapSharding(devices, sharding_spec)
。
如果是為了作為 pjit
的輸入而建立,則 sharding 可以是 jax.sharding.NamedSharding(mesh, pspec)
。
切換到 jax.Array 以處理主機本機輸入後,pjit 的重大變更#
如果您完全使用 GDA 引數到 pjit,則可以跳過本節! 🎉
啟用 jax.Array
後,pjit
的所有輸入都必須是全域形狀。這是與先前行為的重大變更,先前行為中 pjit
會將進程本機引數串連成全域值;此串連不再發生。
我們為何進行此重大變更?現在每個陣列都明確說明其本機分片如何融入全域整體,而不是讓它隱含。更明確的表示法也解鎖了額外的彈性,例如將非連續網格與 pjit
搭配使用,這可以提高某些 TPU 模型上的效率。
當啟用 jax.Array
時,執行 多進程 pjit 計算 並傳遞主機本機輸入可能會導致類似於此的錯誤
範例
網格 = {'x': 2, 'y': 2, 'z': 2}
和主機本機輸入形狀 == (4,)
和 pspec = P(('x', 'y', 'z'))
由於 pjit
不使用 jax.Array
將主機本機形狀提升為全域形狀,因此您會收到以下錯誤
注意:僅當您的主機本機形狀小於網格的形狀時,您才會看到此錯誤。
ValueError: One of pjit arguments was given the sharding of
NamedSharding(mesh={'x': 2, 'y': 2, 'chips': 2}, partition_spec=PartitionSpec(('x', 'y', 'chips'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4
此錯誤是有道理的,因為當維度 0
上的值為 4
時,您無法將維度 0 分片 8 種方式。
如果您仍然將主機本機輸入傳遞給 pjit
,您該如何遷移?我們正在提供過渡性 API 來協助您遷移
注意:如果您在單一進程上執行您的 pjitted 計算,則不需要這些公用程式。
from jax.experimental import multihost_utils
global_inps = multihost_utils.host_local_array_to_global_array(
local_inputs, mesh, in_pspecs)
global_outputs = pjit(f, in_shardings=in_pspecs,
out_shardings=out_pspecs)(global_inps)
local_outs = multihost_utils.global_array_to_host_local_array(
global_outputs, mesh, out_pspecs)
host_local_array_to_global_array
是一種類型轉換,它會查看僅具有本機分片的值,並將其本機形狀變更為 pjit
在變更之前假設傳遞該值時的形狀。
仍然支援傳遞完全複製的輸入,即每個進程上的形狀相同,並以 P(None)
作為 in_axis_resources
。在這種情況下,您不必使用 host_local_array_to_global_array
,因為形狀已經是全域的。
key = jax.random.PRNGKey(1)
# As you can see, using host_local_array_to_global_array is not required since in_axis_resources says
# that the input is fully replicated via P(None)
pjit(f, in_shardings=None, out_shardings=None)(key)
# Mixing inputs
global_inp = multihost_utils.host_local_array_to_global_array(
local_inp, mesh, P('data'))
global_out = pjit(f, in_shardings=(P(None), P('data')),
out_shardings=...)(key, global_inp)
FROM_GDA 和 jax.Array#
如果您在 pjit
的 in_axis_resources
引數中使用 FROM_GDA
,則使用 jax.Array
時,無需將任何內容傳遞到 in_axis_resources
,因為 jax.Array
將遵循計算遵循 sharding 語義。
例如
pjit(f, in_shardings=FROM_GDA, out_shardings=...) can be replaced by pjit(f, out_shardings=...)
如果您的 PartitionSpecs 與用於 numpy 陣列等輸入的 FROM_GDA
混合在一起,則使用 host_local_array_to_global_array
將它們轉換為 jax.Array
。
例如
如果您有這個
pjitted_f = pjit(
f, in_shardings=(FROM_GDA, P('x'), FROM_GDA, P(None)),
out_shardings=...)
pjitted_f(gda1, np_array1, gda2, np_array2)
那麼您可以將其替換為
pjitted_f = pjit(f, out_shardings=...)
array2, array3 = multihost_utils.host_local_array_to_global_array(
(np_array1, np_array2), mesh, (P('x'), P(None)))
pjitted_f(array1, array2, array3, array4)
live_buffers 已替換為 live_arrays#
jax Device
上的 live_buffers
屬性已被棄用。請改用與 jax.Array
相容的 jax.live_arrays()
。
處理批次等主機本機輸入到 pjit#
如果您在多進程環境中將主機本機輸入傳遞到 pjit
,請使用 multihost_utils.host_local_array_to_global_array
將批次轉換為全域 jax.Array
,然後將其傳遞到 pjit
。
此類主機本機輸入最常見的範例是輸入資料批次。
這適用於任何主機本機輸入(而不僅僅是輸入資料批次)。
from jax.experimental import multihost_utils
batch = multihost_utils.host_local_array_to_global_array(
batch, mesh, batch_partition_spec)
有關此變更和更多範例的詳細資訊,請參閱上面的 pjit 章節。
RecursionError:遞迴呼叫 jit#
當您的程式碼的某些部分停用了 jax.Array
,然後您僅針對其他部分啟用它時,就會發生這種情況。例如,如果您使用某些停用了 jax.Array
的第三方程式碼,並且您從該程式庫取得 DeviceArray
,然後您在您的程式庫中啟用 jax.Array
,並將該 DeviceArray
傳遞給 JAX 函式,則會導致 RecursionError。
當預設啟用 jax.Array
,以便所有程式庫都傳回 jax.Array
,除非它們明確停用它時,此錯誤應該會消失。