傳輸保護#
JAX 可能在類型轉換和輸入分片期間,於主機和裝置之間以及裝置之間傳輸資料。為了記錄或禁止任何非預期的傳輸,使用者可以設定 JAX 傳輸保護。
JAX 傳輸保護區分兩種傳輸類型
顯式傳輸:
jax.device_put*()
和jax.device_get()
呼叫。隱式傳輸:其他傳輸 (例如,列印
DeviceArray
)。
傳輸保護可以根據其保護等級採取動作
"allow"
:靜默允許所有傳輸 (預設)。"log"
:記錄並允許隱式傳輸。靜默允許顯式傳輸。"disallow"
:禁止隱式傳輸。靜默允許顯式傳輸。"log_explicit"
:記錄並允許所有傳輸。"disallow_explicit"
:禁止所有傳輸。
當禁止傳輸時,JAX 將引發 RuntimeError
。
傳輸保護使用標準 JAX 配置系統
命令列標記和--jax_transfer_guard=GUARD_LEVEL
將設定全域選項。jax.config.update("jax_transfer_guard", GUARD_LEVEL)
內容管理器將在內容管理器的範圍內設定執行緒本地選項。with jax.transfer_guard(GUARD_LEVEL): ...
請注意,與其他 JAX 配置選項類似,新產生的執行緒將使用全域選項,而不是產生執行緒的範圍的任何活動執行緒本地選項。
傳輸保護也可以根據傳輸方向更有選擇性地應用。標記和內容管理器名稱會附加對應的傳輸方向後綴 (例如,
和 --jax_transfer_guard_host_to_device
)jax.config.transfer_guard_host_to_device
:將 Python 值或 NumPy 陣列轉換為 JAX 裝置上緩衝區。"host_to_device"
:將 JAX 裝置上緩衝區複製到不同的裝置。"device_to_device"
:擷取 JAX 裝置上緩衝區。"device_to_host"
無論傳輸保護等級如何,始終允許擷取 CPU 裝置上的緩衝區。
以下顯示使用傳輸保護的範例。
>>> jax.config.update("jax_transfer_guard", "allow") # This is default.
>>>
>>> x = jnp.array(1)
>>> y = jnp.array(2)
>>> z = jnp.array(3)
>>>
>>> print("x", x) # All transfers are allowed.
x 1
>>> with jax.transfer_guard("disallow"):
... print("x", x) # x has already been fetched into the host.
... print("y", jax.device_get(y)) # Explicit transfers are allowed.
... try:
... print("z", z) # Implicit transfers are disallowed.
... assert False, "This line is expected to be unreachable."
... except:
... print("z could not be fetched")
x 1
y 2
z could not be fetched