jax.numpy.promote_types#
- jax.numpy.promote_types(a, b)[原始碼]#
傳回二元運算應將其引數轉換成的型別。
JAX 版本的
numpy.promote_types()
實作。關於 JAX 型別提升語意的詳細資訊,請參閱型別提升語意。- 參數:
a (DTypeLike) –
numpy.dtype
或型別指定符。b (DTypeLike) –
numpy.dtype
或型別指定符。
- 傳回值:
一個
numpy.dtype
物件。- 回傳型別:
DType
範例
型別指定符可以是字串、dtypes 或純量型別,而傳回值永遠是 dtype
>>> jnp.promote_types('int32', 'float32') # strings dtype('float32') >>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32')) # dtypes dtype('float32') >>> jnp.promote_types(jnp.int32, jnp.float32) # scalar types dtype('float32')
內建純量型別 (
int
、float
或complex
) 被視為弱型別,且不會更改強型別對應項的位元寬度 (請參閱型別提升語意中的討論)>>> jnp.promote_types('uint8', int) dtype('uint8') >>> jnp.promote_types('float16', float) dtype('float16')
這與此函式的 NumPy 版本不同,NumPy 版本將內建純量型別視為等同於 64 位元型別
>>> import numpy >>> numpy.promote_types('uint8', int) dtype('int64') >>> numpy.promote_types('float16', float) dtype('float64')