jax.numpy.matvec#
- jax.numpy.matvec(x1, x2, /)[原始碼]#
批次矩陣向量乘積。
numpy.matvec()
的 JAX 實作。- 參數:
x1 (ArrayLike) – 形狀為
(..., M, N)
的陣列x2 (ArrayLike) – 形狀為
(..., N)
的陣列。前導維度必須與x1
的前導維度廣播相容。
- 回傳:
形狀為
(..., M)
,包含批次矩陣向量乘積的陣列。- 回傳型別:
參見
jax.numpy.linalg.vecdot()
:批次向量乘積。jax.numpy.vecmat()
:向量矩陣乘積。jax.numpy.matmul()
:一般矩陣乘法。
範例
簡單矩陣向量乘積
>>> x1 = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> x2 = jnp.array([7, 8, 9]) >>> jnp.matvec(x1, x2) Array([ 50, 122], dtype=int32)
批次矩陣向量乘積
>>> x2 = jnp.array([[7, 8, 9], ... [5, 6, 7]]) >>> jnp.matvec(x1, x2) Array([[ 50, 122], [ 38, 92]], dtype=int32)