jax.lax.dot_general#

jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None, out_sharding=None)[原始碼]#

通用的點積/縮併運算子。

包裝 XLA 的 DotGeneral 運算子。

dot_general 的語意很複雜,但大多數使用者應該不需要直接使用它。相反地,您可以使用更高等級的函式,例如 jax.numpy.dot()jax.numpy.matmul()jax.numpy.tensordot()jax.numpy.einsum() 和其他函式,這些函式將在底層建構對 dot_general 的適當呼叫。如果您真的想了解 dot_general 本身,我們建議您閱讀 XLA 的 DotGeneral 運算子文件。

參數:
  • lhs (ArrayLike) – 一個陣列

  • rhs (ArrayLike) – 一個陣列

  • dimension_numbers (DotDimensionNumbers) – 一個元組的元組,包含整數序列,形式為 ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

  • precision (PrecisionLike | None) –

    選用。此參數控制計算的數值,可以是以下其中之一

  • preferred_element_type (DTypeLike | None | None) – 選用。此參數控制點積輸出的資料型別。預設情況下,此運算的輸出元素型別將根據通常的型別提升規則,與 lhsrhs 輸入元素型別相符。將 preferred_element_type 設定為特定的 dtype 將表示運算傳回該元素型別。當 precision 不是 DotAlgorithmDotAlgorithmPreset 時,preferred_element_type 提供編譯器提示,以使用此資料型別累積點積。

傳回:

一個陣列,其第一個維度是(共用的)批次維度,接著是 lhs 非縮併/非批次維度,最後是 rhs 非縮併/非批次維度。

傳回型別:

Array