jax.experimental.io_callback#

jax.experimental.io_callback(callback, result_shape_dtypes, *args, sharding=None, ordered=False, **kwargs)[原始碼]#

呼叫不純的 Python 回呼函式。

更多說明,請參閱外部回呼

參數:
  • callback (Callable[..., Any]) – 要在主機上執行的函式。假設它是不純函式。如果 callback 是純函式,則使用 jax.pure_callback() 可能會帶來更有效率的執行。

  • result_shape_dtypes (Any) – pytree,其葉節點具有 shapedtype 屬性,其結構符合執行階段回呼函式的預期輸出。jax.ShapeDtypeStruct 通常用於定義葉節點值。

  • *args (Any) – 要傳遞給回呼函式的引數

  • sharding (SingleDeviceSharding | None | None) – 選擇性的分片,指定應從哪個裝置調用回呼。

  • ordered (bool) – 布林值,指定對回呼的循序呼叫是否必須排序。

  • **kwargs (Any) – 要傳遞給回呼函式的關鍵字引數

傳回:

結構符合

result_shape_dtypes 的 jax.Array 物件的 pytree.

傳回類型:

result

參見