傳輸保護#

JAX 可能在類型轉換和輸入分片期間,於主機和裝置之間以及裝置之間傳輸資料。為了記錄或禁止任何非預期的傳輸,使用者可以設定 JAX 傳輸保護。

JAX 傳輸保護區分兩種傳輸類型

  • 顯式傳輸:jax.device_put*()jax.device_get() 呼叫。

  • 隱式傳輸:其他傳輸 (例如,列印 DeviceArray)。

傳輸保護可以根據其保護等級採取動作

  • "allow":靜默允許所有傳輸 (預設)。

  • "log":記錄並允許隱式傳輸。靜默允許顯式傳輸。

  • "disallow":禁止隱式傳輸。靜默允許顯式傳輸。

  • "log_explicit":記錄並允許所有傳輸。

  • "disallow_explicit":禁止所有傳輸。

當禁止傳輸時,JAX 將引發 RuntimeError

傳輸保護使用標準 JAX 配置系統

  • --jax_transfer_guard=GUARD_LEVEL 命令列標記和 jax.config.update("jax_transfer_guard", GUARD_LEVEL) 將設定全域選項。

  • with jax.transfer_guard(GUARD_LEVEL): ... 內容管理器將在內容管理器的範圍內設定執行緒本地選項。

請注意,與其他 JAX 配置選項類似,新產生的執行緒將使用全域選項,而不是產生執行緒的範圍的任何活動執行緒本地選項。

傳輸保護也可以根據傳輸方向更有選擇性地應用。標記和內容管理器名稱會附加對應的傳輸方向後綴 (例如,--jax_transfer_guard_host_to_devicejax.config.transfer_guard_host_to_device)

  • "host_to_device":將 Python 值或 NumPy 陣列轉換為 JAX 裝置上緩衝區。

  • "device_to_device":將 JAX 裝置上緩衝區複製到不同的裝置。

  • "device_to_host":擷取 JAX 裝置上緩衝區。

無論傳輸保護等級如何,始終允許擷取 CPU 裝置上的緩衝區。

以下顯示使用傳輸保護的範例。

>>> jax.config.update("jax_transfer_guard", "allow")  # This is default.
>>>
>>> x = jnp.array(1)
>>> y = jnp.array(2)
>>> z = jnp.array(3)
>>>
>>> print("x", x)  # All transfers are allowed.
x 1
>>> with jax.transfer_guard("disallow"):
...   print("x", x)  # x has already been fetched into the host.
...   print("y", jax.device_get(y))  # Explicit transfers are allowed.
...   try:
...     print("z", z)  # Implicit transfers are disallowed.
...     assert False, "This line is expected to be unreachable."
...   except:
...     print("z could not be fetched")  
x 1
y 2
z could not be fetched