jax.pure_callback#
- jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=Deprecated, vmap_method=None, **kwargs)[原始碼]#
呼叫純 Python 回呼。適用於
jit()
/vmap()
/等等。如需更多說明,請參閱外部回呼。
pure_callback
允許在 JIT 編譯的 JAX 函式中呼叫 Python 函式。輸入callback
將傳遞放置在本地 CPU 上的 JAX 陣列,並且也應傳回 CPU 上的 JAX 陣列。回呼被視為功能上純粹的,表示它沒有副作用,並且其輸出值僅取決於其參數值。因此,多次呼叫它是安全的(例如,當由
vmap()
或pmap()
轉換時),或者當例如 jit 裝飾函式的輸出對其值沒有資料依賴性時,可以完全不呼叫。如果資料依賴性允許,純回呼也可以重新排序。當 vmap 化時,行為將取決於
vmap_method
的值。在沒有明確
vmap_method
的情況下,對回呼呼叫vmap()
已被棄用,最終將引發NotImplementedError
。vmap_method="sequential"
使用map()
迴圈遍歷批次參數,為每個批次元素呼叫一次callback
。vmap_method="expand_dims"
使用大小為1
的新軸呼叫callback
,這些軸作為前導維度新增至未批次輸入。vmap_method="broadcast_all"
的行為類似於expand_dims
,但輸入會平鋪到預期的批次形狀。
如有必要,可以使用
vmap_method="legacy_vectorized"
還原已棄用的vectorized=True
參數提供的舊版行為。目前的預設行為是在未指定時使用
vmap_method="sequential"
,但此行為已被棄用,未來,除非明確指定vmap_method
,否則預設行為將是引發NotImplementedError
。- 參數:
callback (Callable[..., Any]) – 要在主機上執行的函式。回呼假定為純函式(即沒有副作用的函式):如果傳遞了不純函式,它可能會以意外的方式運作,尤其是在轉換下。可呼叫物件將傳遞陣列的 PyTrees 作為參數,並且應傳回與
result_shape_dtypes
相符的陣列 PyTree。result_shape_dtypes (Any) – pytree,其葉節點具有
shape
和dtype
屬性,其結構與執行階段回呼函式的預期輸出相符。jax.ShapeDtypeStruct
通常用於定義葉節點值。*args (Any) – 要傳遞給回呼函式的參數
sharding (SingleDeviceSharding | None | None) – 可選分片,指定應從哪個裝置調用回呼。
vmap_method (str | None | None) – 字串,指定回呼如何在
vmap()
下轉換,如上所述。**kwargs (Any) – 要傳遞給回呼函式的關鍵字參數
vectorized (bool | None | DeprecatedArg)
- 傳回:
- 一個
jax.Array
物件的 pytree,其結構與 result_shape_dtypes
.
- 一個
- 傳回類型:
result
參見
jax.experimental.io_callback()
:專為不純函式設計的回呼。jax.debug.callback()
:專為通用除錯設計的回呼。jax.debug.print()
:專為列印設計的回呼。
範例
pure_callback
在vmap()
下的行為由vmap_method
參數控制,如上所述。考慮一些明確的範例來示範語義是很有用的。例如,考慮以下函式>>> def callback(x, y): ... print(jnp.shape(x), jnp.shape(y)) ... return x + y
>>> def fun(x, y, *, vmap_method): ... shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y)) ... dtype = jnp.result_type(x, y) ... out_type = jax.ShapeDtypeStruct(shape, dtype) ... return jax.pure_callback(callback, out_type, x, y, ... vmap_method=vmap_method)
使用
vmap_method="expand_dims"
呼叫此函式會將大小為1
的新軸新增至y
>>> from functools import partial >>> x = jnp.arange(4) >>> y = 1.0 >>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y) (4,) (1,) Array([1., 2., 3., 4.], dtype=float32)
然而,
vmap_method="broadcast_all"
會將大小為4
的軸新增至y
>>> jax.vmap(partial(fun, vmap_method="broadcast_all"), ... in_axes=(0, None))(x, y) (4,) (4,) Array([1., 2., 3., 4.], dtype=float32)