jax.numpy.result_type#
- jax.numpy.result_type(*args)[原始碼]#
傳回將 JAX 提升規則應用於輸入的結果。
JAX 實作的
numpy.result_type()
。JAX 的 dtype 提升行為在類型提升語意中描述。
- 參數:
args (Any) – 一個或多個陣列或類似 dtype 的物件。
- 傳回值:
一個
numpy.dtype
實例,表示輸入的類型提升結果。- 傳回類型:
DType
範例
輸入可以是 dtype 指定符
>>> jnp.result_type('int32', 'float32') dtype('float32') >>> jnp.result_type(np.uint16, np.dtype('int32')) dtype('int32')
輸入也可能是純量或陣列
>>> jnp.result_type(1.0, jnp.bfloat16(2)) dtype(bfloat16) >>> jnp.result_type(jnp.arange(4), jnp.zeros(4)) dtype('float32')
請注意,結果類型將根據
jax_enable_x64
配置旗標的狀態進行正規化,這表示 64 位元類型可能會降級為 32 位元>>> jnp.result_type('float64') dtype('float32')
有關 64 位元值的詳細資訊,請參閱Sharp bits - double precision