jax.numpy.promote_types#

jax.numpy.promote_types(a, b)[原始碼]#

傳回二元運算應將其引數轉換成的型別。

JAX 版本的 numpy.promote_types() 實作。關於 JAX 型別提升語意的詳細資訊,請參閱型別提升語意

參數:
傳回值:

一個 numpy.dtype 物件。

回傳型別:

DType

範例

型別指定符可以是字串、dtypes 或純量型別,而傳回值永遠是 dtype

>>> jnp.promote_types('int32', 'float32')  # strings
dtype('float32')
>>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32'))  # dtypes
dtype('float32')
>>> jnp.promote_types(jnp.int32, jnp.float32)  # scalar types
dtype('float32')

內建純量型別 ( intfloatcomplex ) 被視為弱型別,且不會更改強型別對應項的位元寬度 (請參閱型別提升語意中的討論)

>>> jnp.promote_types('uint8', int)
dtype('uint8')
>>> jnp.promote_types('float16', float)
dtype('float16')

這與此函式的 NumPy 版本不同,NumPy 版本將內建純量型別視為等同於 64 位元型別

>>> import numpy
>>> numpy.promote_types('uint8', int)
dtype('int64')
>>> numpy.promote_types('float16', float)
dtype('float64')