jax.lax.dot_general#
- jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None, out_sharding=None)[原始碼]#
通用的點積/縮併運算子。
包裝 XLA 的 DotGeneral 運算子。
dot_general
的語意很複雜,但大多數使用者應該不需要直接使用它。相反地,您可以使用更高等級的函式,例如jax.numpy.dot()
、jax.numpy.matmul()
、jax.numpy.tensordot()
、jax.numpy.einsum()
和其他函式,這些函式將在底層建構對dot_general
的適當呼叫。如果您真的想了解dot_general
本身,我們建議您閱讀 XLA 的 DotGeneral 運算子文件。- 參數:
lhs (ArrayLike) – 一個陣列
rhs (ArrayLike) – 一個陣列
dimension_numbers (DotDimensionNumbers) – 一個元組的元組,包含整數序列,形式為
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
precision (PrecisionLike | None) –
選用。此參數控制計算的數值,可以是以下其中之一
None
,表示目前後端的預設精度,一個
DotAlgorithm
或一個DotAlgorithmPreset
,表示必須用於累積點積的演算法。
preferred_element_type (DTypeLike | None | None) – 選用。此參數控制點積輸出的資料型別。預設情況下,此運算的輸出元素型別將根據通常的型別提升規則,與
lhs
和rhs
輸入元素型別相符。將preferred_element_type
設定為特定的dtype
將表示運算傳回該元素型別。當precision
不是DotAlgorithm
或DotAlgorithmPreset
時,preferred_element_type
提供編譯器提示,以使用此資料型別累積點積。
- 傳回:
一個陣列,其第一個維度是(共用的)批次維度,接著是
lhs
非縮併/非批次維度,最後是rhs
非縮併/非批次維度。- 傳回型別: