jax.lax.bitcast_convert_type#
- jax.lax.bitcast_convert_type(operand, new_dtype)[原始碼]#
逐元素位元轉換。
包裝 XLA 的 BitcastConvertType 運算子,該運算子執行從一種型別到另一種型別的位元轉換。
輸出形狀取決於輸入和輸出 dtype 的大小,邏輯如下
if new_dtype.itemsize == operand.dtype.itemsize: output_shape = operand.shape if new_dtype.itemsize < operand.dtype.itemsize: output_shape = (*operand.shape, operand.dtype.itemsize // new_dtype.itemsize) if new_dtype.itemsize > operand.dtype.itemsize: assert operand.shape[-1] * operand.dtype.itemsize == new_dtype.itemsize output_shape = operand.shape[:-1]
- 參數:
operand (ArrayLike) – 要轉換的陣列或純量值
new_dtype (DTypeLike) – 新型別。應為 NumPy 型別。
- 傳回:
形狀為 output_shape(見上文)且型別為 new_dtype 的陣列,由與 operand 相同的位元建構而成。
- 傳回型別: