型別提升語意#
本文檔描述 JAX 的型別提升規則,即 jax.numpy.promote_types()
對於每對型別的結果。關於設計以下描述內容的考量背景,請參閱 JAX 型別提升語意設計。
JAX 的型別提升行為透過以下型別提升格狀結構決定
例如
b1
表示np.bool_
,i2
表示np.int16
,u4
表示np.uint32
,bf
表示np.bfloat16
,f2
表示np.float16
,c8
表示np.complex64
,i*
表示 Pythonint
或弱型別int
,f*
表示 Pythonfloat
或弱型別float
,以及c*
表示 Pythoncomplex
或弱型別complex
。
(關於弱型別的更多資訊,請參閱下方的JAX 中的弱型別值)。
任何兩種型別之間的提升由它們在此格狀結構上的 join 給出,這會產生以下二元提升表
b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
b1 | b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
u1 | u1 | u1 | u2 | u4 | u8 | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u1 | f* | c* |
u2 | u2 | u2 | u2 | u4 | u8 | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u2 | f* | c* |
u4 | u4 | u4 | u4 | u4 | u8 | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c8 | c16 | u4 | f* | c* |
u8 | u8 | u8 | u8 | u8 | u8 | f* | f* | f* | f* | bf | f2 | f4 | f8 | c8 | c16 | u8 | f* | c* |
i1 | i1 | i2 | i4 | i8 | f* | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i1 | f* | c* |
i2 | i2 | i2 | i4 | i8 | f* | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i2 | f* | c* |
i4 | i4 | i4 | i4 | i8 | f* | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i4 | f* | c* |
i8 | i8 | i8 | i8 | i8 | f* | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i8 | f* | c* |
bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | f4 | f4 | f8 | c8 | c16 | bf | bf | c8 |
f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f4 | f2 | f4 | f8 | c8 | c16 | f2 | f2 | c8 |
f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f8 | c8 | c16 | f4 | f4 | c8 |
f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | c16 | c16 | f8 | f8 | c16 |
c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c16 | c8 | c16 | c8 | c8 | c8 |
c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 | c16 |
i* | i* | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c8 | c16 | i* | f* | c* |
f* | f* | f* | f* | f* | f* | f* | f* | f* | f* | bf | f2 | f4 | f8 | c8 | c16 | f* | f* | c* |
c* | c* | c* | c* | c* | c* | c* | c* | c* | c* | c8 | c8 | c8 | c16 | c8 | c16 | c* | c* | c* |
Jax 的型別提升規則與 NumPy 的規則不同,如上述表格中以綠色背景突出顯示的儲存格所示,NumPy 的規則由 numpy.promote_types()
給出。主要有三類差異
當將弱型別值與相同類別的強型別 JAX 值進行提升時,JAX 始終偏好 JAX 值的精確度。例如,
jnp.int16(1) + 1
將返回int16
,而不是像 NumPy 中那樣提升為int64
。請注意,這僅適用於 Python 純量值;如果常數是 NumPy 陣列,則上述格狀結構用於型別提升。例如,jnp.int16(1) + np.array(1)
將返回int64
。當將整數或布林型別與浮點型別或複數型別進行提升時,JAX 始終偏好浮點型別或複數型別的型別。
JAX 支援 bfloat16 非標準 16 位元浮點型別 (
jax.numpy.bfloat16
),這對於神經網路訓練很有用。唯一值得注意的提升行為是關於 IEEE-754float16
,bfloat16
與之提升為float32
。
NumPy 和 JAX 之間的差異是由於加速器裝置(例如 GPU 和 TPU)使用 64 位元浮點型別會產生顯著的效能損失 (GPU) 或根本不支援 64 位元浮點型別 (TPU) 所驅動的。經典 NumPy 的提升規則太過願意過度提升到 64 位元型別,這對於旨在在加速器上運行的系統來說是有問題的。
JAX 使用的浮點提升規則更適合現代加速器裝置,並且在提升浮點型別方面較不激進。JAX 用於浮點型別的提升規則與 PyTorch 使用的規則相似。
Python 運算子調度的影響#
請記住,像 + 這樣的 Python 運算子將根據正在相加的兩個值的 Python 型別進行調度。這表示,例如,np.int16(1) + 1
將使用 NumPy 規則進行提升,而 jnp.int16(1) + 1
將使用 JAX 規則進行提升。當兩種提升類型組合在一起時,這可能會導致潛在的混淆性非關聯提升語意;例如,使用 np.int16(1) + 1 + jnp.int16(1)
。
JAX 中的弱型別值#
在大多數情況下,JAX 中的弱型別值可以被認為具有與 Python 純量等效的提升行為,例如以下程式碼中的整數純量 2
>>> x = jnp.arange(5, dtype='int8')
>>> 2 * x
Array([0, 2, 4, 6, 8], dtype=int8)
JAX 的弱型別框架旨在防止 JAX 值與沒有明確使用者指定型別的值(例如 Python 純量字面值)之間的二元運算中發生不必要的型別提升。例如,如果 2
未被視為弱型別,則上述表達式將導致隱式型別提升
>>> jnp.int32(2) * x
Array([0, 2, 4, 6, 8], dtype=int32)
在 JAX 中使用時,Python 純量有時會提升為 DeviceArray
物件,例如在 JIT 編譯期間。為了在這種情況下保持所需的提升語意,DeviceArray
物件攜帶一個 weak_type
旗標,可以在陣列的字串表示中看到
>>> jnp.asarray(2)
Array(2, dtype=int32, weak_type=True)
如果明確指定了 dtype
,則會產生標準的強型別陣列值
>>> jnp.asarray(2, dtype='int32')
Array(2, dtype=int32)
嚴格 dtype 提升#
在某些情況下,停用隱式型別提升行為可能很有用,而是要求所有提升都必須是明確的。這可以在 JAX 中透過將 jax_numpy_dtype_promotion
旗標設定為 'strict'
來完成。在本地,可以使用 context manager 來完成
>>> x = jnp.float32(1)
>>> y = jnp.int32(1)
>>> with jax.numpy_dtype_promotion('strict'):
... z = x + y
...
Traceback (most recent call last):
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.
為了方便起見,嚴格提升模式仍然允許安全的弱型別提升,因此您仍然可以撰寫混合 JAX 陣列和 Python 純量的程式碼
>>> with jax.numpy_dtype_promotion('strict'):
... z = x + 1
>>> print(z)
2.0
如果您希望全域設定組態,可以使用標準組態更新來完成
jax.config.update('jax_numpy_dtype_promotion', 'strict')
若要還原預設標準型別提升,請將此組態設定為 'standard'
jax.config.update('jax_numpy_dtype_promotion', 'standard')