秩提升警告#
NumPy 廣播規則允許自動將引數從一個秩(陣列軸的數量)提升到另一個秩。當本意如此時,此行為可能很方便,但也可能導致令人驚訝的錯誤,其中靜默的秩提升會掩蓋底層的形狀錯誤。
以下是秩提升的範例
>>> import numpy as np
>>> x = np.arange(12).reshape(4, 3)
>>> y = np.array([0, 1, 0])
>>> x + y
array([[ 0, 2, 2],
[ 3, 5, 5],
[ 6, 8, 8],
[ 9, 11, 11]])
為了避免潛在的意外,jax.numpy
是可配置的,因此需要秩提升的表達式可能會導致警告、錯誤,或者可以像常規 NumPy 一樣被允許。配置選項名為 jax_numpy_rank_promotion
,它可以採用字串值 allow
、warn
和 raise
。預設設定為 allow
,它允許秩提升,而不會發出警告或錯誤。raise
設定會在秩提升時引發錯誤,而 warn
會在第一次發生秩提升時發出警告。
秩提升可以使用 jax.numpy_rank_promotion()
上下文管理器在本機啟用或停用
with jax.numpy_rank_promotion("warn"):
z = x + y
此配置也可以通過幾種方式全域設定。一種是在您的程式碼中使用 jax.config
import jax
jax.config.update("jax_numpy_rank_promotion", "warn")
您也可以使用環境變數 JAX_NUMPY_RANK_PROMOTION
設定此選項,例如 JAX_NUMPY_RANK_PROMOTION='warn'
。最後,當使用 absl-py
時,可以使用命令列旗標設定此選項。