jax.numpy.result_type#

jax.numpy.result_type(*args)[原始碼]#

傳回將 JAX 提升規則應用於輸入的結果。

JAX 實作的 numpy.result_type()

JAX 的 dtype 提升行為在類型提升語意中描述。

參數:

args (Any) – 一個或多個陣列或類似 dtype 的物件。

傳回值:

一個 numpy.dtype 實例,表示輸入的類型提升結果。

傳回類型:

DType

範例

輸入可以是 dtype 指定符

>>> jnp.result_type('int32', 'float32')
dtype('float32')
>>> jnp.result_type(np.uint16, np.dtype('int32'))
dtype('int32')

輸入也可能是純量或陣列

>>> jnp.result_type(1.0, jnp.bfloat16(2))
dtype(bfloat16)
>>> jnp.result_type(jnp.arange(4), jnp.zeros(4))
dtype('float32')

請注意,結果類型將根據 jax_enable_x64 配置旗標的狀態進行正規化,這表示 64 位元類型可能會降級為 32 位元

>>> jnp.result_type('float64')
dtype('float32')

有關 64 位元值的詳細資訊,請參閱Sharp bits - double precision