jax.device_count#

jax.device_count(backend=None)[原始碼]#

傳回裝置總數。

在大多數平台上,這與 jax.local_device_count() 相同。然而,在多進程平台上,不同的裝置與不同的進程相關聯,這將傳回所有進程的裝置總數。

參數:

backend (str | xla_client.Client | None | None) – 這是一個實驗性功能,API 可能會變更。選填,一個表示 xla 後端的字串:'cpu''gpu''tpu'

傳回:

裝置數量。

傳回型別:

int