jax.numpy.matvec#

jax.numpy.matvec(x1, x2, /)[原始碼]#

批次矩陣向量乘積。

numpy.matvec() 的 JAX 實作。

參數:
  • x1 (ArrayLike) – 形狀為 (..., M, N) 的陣列

  • x2 (ArrayLike) – 形狀為 (..., N) 的陣列。前導維度必須與 x1 的前導維度廣播相容。

回傳:

形狀為 (..., M),包含批次矩陣向量乘積的陣列。

回傳型別:

Array

參見

範例

簡單矩陣向量乘積

>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> x2 = jnp.array([7, 8, 9])
>>> jnp.matvec(x1, x2)
Array([ 50, 122], dtype=int32)

批次矩陣向量乘積

>>> x2 = jnp.array([[7, 8, 9],
...                 [5, 6, 7]])
>>> jnp.matvec(x1, x2)
Array([[ 50, 122],
       [ 38,  92]], dtype=int32)