jax.experimental.io_callback#
- jax.experimental.io_callback(callback, result_shape_dtypes, *args, sharding=None, ordered=False, **kwargs)[原始碼]#
呼叫不純的 Python 回呼函式。
更多說明,請參閱外部回呼。
- 參數:
callback (Callable[..., Any]) – 要在主機上執行的函式。假設它是不純函式。如果
callback
是純函式,則使用jax.pure_callback()
可能會帶來更有效率的執行。result_shape_dtypes (Any) – pytree,其葉節點具有
shape
和dtype
屬性,其結構符合執行階段回呼函式的預期輸出。jax.ShapeDtypeStruct
通常用於定義葉節點值。*args (Any) – 要傳遞給回呼函式的引數
sharding (SingleDeviceSharding | None | None) – 選擇性的分片,指定應從哪個裝置調用回呼。
ordered (bool) – 布林值,指定對回呼的循序呼叫是否必須排序。
**kwargs (Any) – 要傳遞給回呼函式的關鍵字引數
- 傳回:
- 結構符合
result_shape_dtypes 的
.jax.Array
物件的 pytree
- 傳回類型:
result
參見
jax.pure_callback()
:為純函式設計的回呼。jax.debug.callback()
:為通用除錯設計的回呼。jax.debug.print()
:為列印設計的回呼。