jax.experimental.multihost_utils.global_array_to_host_local_array#

jax.experimental.multihost_utils.global_array_to_host_local_array(global_inputs, global_mesh, pspecs)[來源]#

將全域 jax.Array 轉換為主機本機 jax.Array

您可以使用此函數轉換到 jax.Array。搭配 pjit 使用 jax.Array 具有與搭配 pjit 使用 GDA 相同的語意,即 pjit 的所有 jax.Array 輸入都應為全域形狀,而來自 pjit 的輸出也將是全域形狀的 jax.Array

您可以使用此函數將來自 pjit 的全域形狀 jax.Array 輸出再次轉換為主機本機值,以便轉換到 jax.Array 可以是一個機械式的變更。

使用範例

>>> from jax.experimental import multihost_utils 
>>>
>>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) 
>>>
>>> with mesh: 
...   global_out = pjitted_fun(global_inputs) 
>>>
>>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) 
參數:
傳回:

主機本機陣列的 Pytree。