jax.pure_callback#

jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=Deprecated, vmap_method=None, **kwargs)[原始碼]#

呼叫純 Python 回呼。適用於 jit()/vmap()/等等。

如需更多說明,請參閱外部回呼

pure_callback 允許在 JIT 編譯的 JAX 函式中呼叫 Python 函式。輸入 callback 將傳遞放置在本地 CPU 上的 JAX 陣列,並且也應傳回 CPU 上的 JAX 陣列。

回呼被視為功能上純粹的,表示它沒有副作用,並且其輸出值僅取決於其參數值。因此,多次呼叫它是安全的(例如,當由 vmap()pmap() 轉換時),或者當例如 jit 裝飾函式的輸出對其值沒有資料依賴性時,可以完全不呼叫。如果資料依賴性允許,純回呼也可以重新排序。

vmap 化時,行為將取決於 vmap_method 的值。

  • 在沒有明確 vmap_method 的情況下,對回呼呼叫 vmap() 已被棄用,最終將引發 NotImplementedError

  • vmap_method="sequential" 使用 map() 迴圈遍歷批次參數,為每個批次元素呼叫一次 callback

  • vmap_method="expand_dims" 使用大小為 1 的新軸呼叫 callback,這些軸作為前導維度新增至未批次輸入。

  • vmap_method="broadcast_all" 的行為類似於 expand_dims,但輸入會平鋪到預期的批次形狀。

如有必要,可以使用 vmap_method="legacy_vectorized" 還原已棄用的 vectorized=True 參數提供的舊版行為。

目前的預設行為是在未指定時使用 vmap_method="sequential",但此行為已被棄用,未來,除非明確指定 vmap_method,否則預設行為將是引發 NotImplementedError

參數:
  • callback (Callable[..., Any]) – 要在主機上執行的函式。回呼假定為純函式(即沒有副作用的函式):如果傳遞了不純函式,它可能會以意外的方式運作,尤其是在轉換下。可呼叫物件將傳遞陣列的 PyTrees 作為參數,並且應傳回與 result_shape_dtypes 相符的陣列 PyTree。

  • result_shape_dtypes (Any) – pytree,其葉節點具有 shapedtype 屬性,其結構與執行階段回呼函式的預期輸出相符。jax.ShapeDtypeStruct 通常用於定義葉節點值。

  • *args (Any) – 要傳遞給回呼函式的參數

  • sharding (SingleDeviceSharding | None | None) – 可選分片,指定應從哪個裝置調用回呼。

  • vmap_method (str | None | None) – 字串,指定回呼如何在 vmap() 下轉換,如上所述。

  • **kwargs (Any) – 要傳遞給回呼函式的關鍵字參數

  • vectorized (bool | None | DeprecatedArg)

傳回:

一個 jax.Array 物件的 pytree,其結構與

result_shape_dtypes.

傳回類型:

result

參見

範例

pure_callbackvmap() 下的行為由 vmap_method 參數控制,如上所述。考慮一些明確的範例來示範語義是很有用的。例如,考慮以下函式

>>> def callback(x, y):
...   print(jnp.shape(x), jnp.shape(y))
...   return x + y
>>> def fun(x, y, *, vmap_method):
...   shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
...   dtype = jnp.result_type(x, y)
...   out_type = jax.ShapeDtypeStruct(shape, dtype)
...   return jax.pure_callback(callback, out_type, x, y,
...                            vmap_method=vmap_method)

使用 vmap_method="expand_dims" 呼叫此函式會將大小為 1 的新軸新增至 y

>>> from functools import partial
>>> x = jnp.arange(4)
>>> y = 1.0
>>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y)
(4,) (1,)
Array([1., 2., 3., 4.], dtype=float32)

然而,vmap_method="broadcast_all" 會將大小為 4 的軸新增至 y

>>> jax.vmap(partial(fun, vmap_method="broadcast_all"),
...          in_axes=(0, None))(x, y)
(4,) (4,)
Array([1., 2., 3., 4.], dtype=float32)