jax.make_array_from_callback#

jax.make_array_from_callback(shape, sharding, data_callback)[原始碼]#

透過從 data_callback 提取的資料,傳回 jax.Array

data_callback 用於提取傳回的 jax.Array 之每個可定址分片的資料。此函式必須傳回具體陣列,表示 make_array_from_callback 與 JAX 轉換(如 jit()vmap())的相容性有限。

參數:
  • shape (Shape) – jax.Array 的形狀。

  • sharding (Sharding | Layout) – Sharding 實例,描述 jax.Array 如何在裝置之間佈局。

  • data_callback (Callable[[Index | None], ArrayLike]) – 回呼函式,將全域陣列值的索引作為輸入,並傳回全域陣列值的相應資料。資料可以任何類陣列物件傳回,例如 numpy.ndarray

傳回:

透過從 data_callback 提取的資料,取得 jax.Array

傳回型別:

ArrayImpl

範例

>>> import math
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> input_shape = (8, 8)
>>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
...
>>> def cb(index):
...  return global_input_data[index]
...
>>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb)
>>> arr.addressable_data(0).shape
(4, 2)