jax.lax.dot#

jax.lax.dot(lhs, rhs, precision=None, preferred_element_type=None)[來源]#

向量/向量、矩陣/向量和矩陣/矩陣乘法。

包裝 XLA 的 Dot 運算子。

如需更一般的收縮,請參閱 jax.lax.dot_general() 運算子。

參數:
  • lhs (Array) – 維度為 1 或 2 的陣列。

  • rhs (Array) – 維度為 1 或 2 的陣列。

  • precision (PrecisionLike | None) –

    選用。此參數控制計算的數值,可以是下列其中一項

  • preferred_element_type (DTypeLike | None | None) – 選用。此參數控制點積輸出的資料類型。預設情況下,此運算的輸出元素類型將符合常用類型提升規則下的 lhsrhs 輸入元素類型。將 preferred_element_type 設定為特定 dtype 將表示運算傳回該元素類型。當 precision 不是 DotAlgorithmDotAlgorithmPreset 時,preferred_element_type 會向編譯器提供提示,以使用此資料類型累積點積。

傳回:

包含乘積的陣列。

傳回類型:

Array