jax.lax.dot#
- jax.lax.dot(lhs, rhs, precision=None, preferred_element_type=None)[來源]#
向量/向量、矩陣/向量和矩陣/矩陣乘法。
包裝 XLA 的 Dot 運算子。
如需更一般的收縮,請參閱
jax.lax.dot_general()
運算子。- 參數:
lhs (Array) – 維度為 1 或 2 的陣列。
rhs (Array) – 維度為 1 或 2 的陣列。
precision (PrecisionLike | None) –
選用。此參數控制計算的數值,可以是下列其中一項
None
,表示目前後端的預設精確度,DotAlgorithm
或DotAlgorithmPreset
,表示必須用於累積點積的演算法。
preferred_element_type (DTypeLike | None | None) – 選用。此參數控制點積輸出的資料類型。預設情況下,此運算的輸出元素類型將符合常用類型提升規則下的
lhs
和rhs
輸入元素類型。將preferred_element_type
設定為特定dtype
將表示運算傳回該元素類型。當precision
不是DotAlgorithm
或DotAlgorithmPreset
時,preferred_element_type
會向編譯器提供提示,以使用此資料類型累積點積。
- 傳回:
包含乘積的陣列。
- 傳回類型: