jax.numpy.dot#
- jax.numpy.dot(a, b, *, precision=None, preferred_element_type=None)[原始碼]#
計算兩個陣列的點積。
numpy.dot()
的 JAX 實作。這與
jax.numpy.matmul()
在兩個方面有所不同如果
a
或b
其中之一是純量,則dot
的結果等同於jax.numpy.multiply()
,而matmul
的結果會是錯誤。如果
a
和b
具有超過 2 個維度,則批次索引會堆疊而不是廣播。
- 參數:
a (ArrayLike) – 第一個輸入陣列,形狀為
(..., N)
。b (ArrayLike) – 第二個輸入陣列。必須具有形狀
(N,)
或(..., N, M)
。在多維情況下,前導維度必須與a
的前導維度廣播相容。precision (PrecisionLike) – 要么
None
(預設),表示後端的預設精度,要么是Precision
列舉值 (Precision.DEFAULT
,Precision.HIGH
或Precision.HIGHEST
),要么是由兩個此類值組成的元組,表示a
和b
的精度。preferred_element_type (DTypeLike | None) – 要么
None
(預設),表示輸入型別的預設累加型別,要么是資料型別,表示將結果累加到該資料型別並傳回具有該資料型別的結果。
- 傳回:
包含輸入點積的陣列,其中
a
和b
的批次維度堆疊而不是廣播。- 回傳類型:
另請參閱
jax.numpy.matmul()
:廣播批次矩陣乘法。jax.lax.dot_general()
:一般批次矩陣乘法。
範例
對於純量輸入,
dot
計算元素級乘積>>> x = jnp.array([1, 2, 3]) >>> jnp.dot(x, 2) Array([2, 4, 6], dtype=int32)
對於向量或矩陣輸入,
dot
計算向量或矩陣乘積>>> M = jnp.array([[2, 3, 4], ... [5, 6, 7], ... [8, 9, 0]]) >>> jnp.dot(M, x) Array([20, 38, 26], dtype=int32) >>> jnp.dot(M, M) Array([[ 51, 60, 29], [ 96, 114, 62], [ 61, 78, 95]], dtype=int32)
對於更高維度的矩陣乘積,批次維度會堆疊,而在
matmul()
中,它們會廣播。例如>>> a = jnp.zeros((3, 2, 4)) >>> b = jnp.zeros((3, 4, 1)) >>> jnp.dot(a, b).shape (3, 2, 3, 1) >>> jnp.matmul(a, b).shape (3, 2, 1)