jax.numpy.dot#

jax.numpy.dot(a, b, *, precision=None, preferred_element_type=None)[原始碼]#

計算兩個陣列的點積。

numpy.dot() 的 JAX 實作。

這與 jax.numpy.matmul() 在兩個方面有所不同

  • 如果 ab 其中之一是純量,則 dot 的結果等同於 jax.numpy.multiply(),而 matmul 的結果會是錯誤。

  • 如果 ab 具有超過 2 個維度,則批次索引會堆疊而不是廣播。

參數:
  • a (ArrayLike) – 第一個輸入陣列,形狀為 (..., N)

  • b (ArrayLike) – 第二個輸入陣列。必須具有形狀 (N,)(..., N, M)。在多維情況下,前導維度必須與 a 的前導維度廣播相容。

  • precision (PrecisionLike) – 要么 None (預設),表示後端的預設精度,要么是 Precision 列舉值 (Precision.DEFAULT, Precision.HIGHPrecision.HIGHEST),要么是由兩個此類值組成的元組,表示 ab 的精度。

  • preferred_element_type (DTypeLike | None) – 要么 None (預設),表示輸入型別的預設累加型別,要么是資料型別,表示將結果累加到該資料型別並傳回具有該資料型別的結果。

傳回:

包含輸入點積的陣列,其中 ab 的批次維度堆疊而不是廣播。

回傳類型:

Array

另請參閱

範例

對於純量輸入,dot 計算元素級乘積

>>> x = jnp.array([1, 2, 3])
>>> jnp.dot(x, 2)
Array([2, 4, 6], dtype=int32)

對於向量或矩陣輸入,dot 計算向量或矩陣乘積

>>> M = jnp.array([[2, 3, 4],
...                [5, 6, 7],
...                [8, 9, 0]])
>>> jnp.dot(M, x)
Array([20, 38, 26], dtype=int32)
>>> jnp.dot(M, M)
Array([[ 51,  60,  29],
       [ 96, 114,  62],
       [ 61,  78,  95]], dtype=int32)

對於更高維度的矩陣乘積,批次維度會堆疊,而在 matmul() 中,它們會廣播。例如

>>> a = jnp.zeros((3, 2, 4))
>>> b = jnp.zeros((3, 4, 1))
>>> jnp.dot(a, b).shape
(3, 2, 3, 1)
>>> jnp.matmul(a, b).shape
(3, 2, 1)