型別提升語意#

本文檔描述 JAX 的型別提升規則,即 jax.numpy.promote_types() 對於每對型別的結果。關於設計以下描述內容的考量背景,請參閱 JAX 型別提升語意設計

JAX 的型別提升行為透過以下型別提升格狀結構決定

_images/type_lattice.svg

例如

  • b1 表示 np.bool_

  • i2 表示 np.int16

  • u4 表示 np.uint32

  • bf 表示 np.bfloat16

  • f2 表示 np.float16

  • c8 表示 np.complex64

  • i* 表示 Python int 或弱型別 int

  • f* 表示 Python float 或弱型別 float,以及

  • c* 表示 Python complex 或弱型別 complex

(關於弱型別的更多資訊,請參閱下方的JAX 中的弱型別值)。

任何兩種型別之間的提升由它們在此格狀結構上的 join 給出,這會產生以下二元提升表

b1u1u2u4u8i1i2i4i8bff2f4f8c8c16i*f*c*
b1b1u1u2u4u8i1i2i4i8bff2f4f8c8c16i*f*c*
u1u1u1u2u4u8i2i2i4i8bff2f4f8c8c16u1f*c*
u2u2u2u2u4u8i4i4i4i8bff2f4f8c8c16u2f*c*
u4u4u4u4u4u8i8i8i8i8bff2f4f8c8c16u4f*c*
u8u8u8u8u8u8f*f*f*f*bff2f4f8c8c16u8f*c*
i1i1i2i4i8f*i1i2i4i8bff2f4f8c8c16i1f*c*
i2i2i2i4i8f*i2i2i4i8bff2f4f8c8c16i2f*c*
i4i4i4i4i8f*i4i4i4i8bff2f4f8c8c16i4f*c*
i8i8i8i8i8f*i8i8i8i8bff2f4f8c8c16i8f*c*
bfbfbfbfbfbfbfbfbfbfbff4f4f8c8c16bfbfc8
f2f2f2f2f2f2f2f2f2f2f4f2f4f8c8c16f2f2c8
f4f4f4f4f4f4f4f4f4f4f4f4f4f8c8c16f4f4c8
f8f8f8f8f8f8f8f8f8f8f8f8f8f8c16c16f8f8c16
c8c8c8c8c8c8c8c8c8c8c8c8c8c16c8c16c8c8c8
c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16
i*i*u1u2u4u8i1i2i4i8bff2f4f8c8c16i*f*c*
f*f*f*f*f*f*f*f*f*f*bff2f4f8c8c16f*f*c*
c*c*c*c*c*c*c*c*c*c*c8c8c8c16c8c16c*c*c*

Jax 的型別提升規則與 NumPy 的規則不同,如上述表格中以綠色背景突出顯示的儲存格所示,NumPy 的規則由 numpy.promote_types() 給出。主要有三類差異

  • 當將弱型別值與相同類別的強型別 JAX 值進行提升時,JAX 始終偏好 JAX 值的精確度。例如,jnp.int16(1) + 1 將返回 int16,而不是像 NumPy 中那樣提升為 int64。請注意,這僅適用於 Python 純量值;如果常數是 NumPy 陣列,則上述格狀結構用於型別提升。例如,jnp.int16(1) + np.array(1) 將返回 int64

  • 當將整數或布林型別與浮點型別或複數型別進行提升時,JAX 始終偏好浮點型別或複數型別的型別。

  • JAX 支援 bfloat16 非標準 16 位元浮點型別 (jax.numpy.bfloat16),這對於神經網路訓練很有用。唯一值得注意的提升行為是關於 IEEE-754 float16bfloat16 與之提升為 float32

NumPy 和 JAX 之間的差異是由於加速器裝置(例如 GPU 和 TPU)使用 64 位元浮點型別會產生顯著的效能損失 (GPU) 或根本不支援 64 位元浮點型別 (TPU) 所驅動的。經典 NumPy 的提升規則太過願意過度提升到 64 位元型別,這對於旨在在加速器上運行的系統來說是有問題的。

JAX 使用的浮點提升規則更適合現代加速器裝置,並且在提升浮點型別方面較不激進。JAX 用於浮點型別的提升規則與 PyTorch 使用的規則相似。

Python 運算子調度的影響#

請記住,像 + 這樣的 Python 運算子將根據正在相加的兩個值的 Python 型別進行調度。這表示,例如,np.int16(1) + 1 將使用 NumPy 規則進行提升,而 jnp.int16(1) + 1 將使用 JAX 規則進行提升。當兩種提升類型組合在一起時,這可能會導致潛在的混淆性非關聯提升語意;例如,使用 np.int16(1) + 1 + jnp.int16(1)

JAX 中的弱型別值#

在大多數情況下,JAX 中的弱型別值可以被認為具有與 Python 純量等效的提升行為,例如以下程式碼中的整數純量 2

>>> x = jnp.arange(5, dtype='int8')
>>> 2 * x
Array([0, 2, 4, 6, 8], dtype=int8)

JAX 的弱型別框架旨在防止 JAX 值與沒有明確使用者指定型別的值(例如 Python 純量字面值)之間的二元運算中發生不必要的型別提升。例如,如果 2 未被視為弱型別,則上述表達式將導致隱式型別提升

>>> jnp.int32(2) * x
Array([0, 2, 4, 6, 8], dtype=int32)

在 JAX 中使用時,Python 純量有時會提升為 DeviceArray 物件,例如在 JIT 編譯期間。為了在這種情況下保持所需的提升語意,DeviceArray 物件攜帶一個 weak_type 旗標,可以在陣列的字串表示中看到

>>> jnp.asarray(2)
Array(2, dtype=int32, weak_type=True)

如果明確指定了 dtype,則會產生標準的強型別陣列值

>>> jnp.asarray(2, dtype='int32')
Array(2, dtype=int32)

嚴格 dtype 提升#

在某些情況下,停用隱式型別提升行為可能很有用,而是要求所有提升都必須是明確的。這可以在 JAX 中透過將 jax_numpy_dtype_promotion 旗標設定為 'strict' 來完成。在本地,可以使用 context manager 來完成

>>> x = jnp.float32(1)
>>> y = jnp.int32(1)
>>> with jax.numpy_dtype_promotion('strict'):
...   z = x + y  
...
Traceback (most recent call last):
TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.

為了方便起見,嚴格提升模式仍然允許安全的弱型別提升,因此您仍然可以撰寫混合 JAX 陣列和 Python 純量的程式碼

>>> with jax.numpy_dtype_promotion('strict'):
...   z = x + 1
>>> print(z)
2.0

如果您希望全域設定組態,可以使用標準組態更新來完成

jax.config.update('jax_numpy_dtype_promotion', 'strict')

若要還原預設標準型別提升,請將此組態設定為 'standard'

jax.config.update('jax_numpy_dtype_promotion', 'standard')