jax.numpy.matmul#
- jax.numpy.matmul(a, b, *, precision=None, preferred_element_type=None)[原始碼]#
執行矩陣乘法。
JAX 實作的
numpy.matmul()
。- 參數:
a (ArrayLike) – 第一個輸入陣列,形狀為
(N,)
或(..., K, N)
。b (ArrayLike) – 第二個輸入陣列。必須具有形狀
(N,)
或(..., N, M)
。在多維情況下,前導維度必須與a
的前導維度廣播相容。precision (PrecisionLike) –
None
(預設值),表示後端的預設精確度;Precision
列舉值 (Precision.DEFAULT
、Precision.HIGH
或Precision.HIGHEST
);或一組包含兩個此類值的元組,指示a
和b
的精確度。preferred_element_type (DTypeLike | None) –
None
(預設值),表示輸入類型的預設累積類型;或資料類型,指示將結果累積到該資料類型並傳回具有該資料類型的結果。
- 傳回值:
包含輸入的矩陣乘積的陣列。如果
b.ndim == 1
,則形狀為a.shape[:-1]
,否則形狀為(..., K, M)
,其中a
和b
的前導維度會一起廣播。- 傳回類型:
另請參閱
jax.numpy.linalg.vecdot()
:批次向量乘積。jax.numpy.linalg.tensordot()
:批次張量乘積。jax.lax.dot_general()
:一般 N 維批次點積。
範例
向量點積
>>> a = jnp.array([1, 2, 3]) >>> b = jnp.array([4, 5, 6]) >>> jnp.matmul(a, b) Array(32, dtype=int32)
矩陣點積
>>> a = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> b = jnp.array([[1, 2], ... [3, 4], ... [5, 6]]) >>> jnp.matmul(a, b) Array([[22, 28], [49, 64]], dtype=int32)
為了方便起見,在所有情況下,您都可以使用
@
運算子執行相同的計算>>> a @ b Array([[22, 28], [49, 64]], dtype=int32)