jax.numpy.vecmat#
- jax.numpy.vecmat(x1, x2, /)[source]#
批次共軛向量-矩陣乘積。
numpy.vecmat()
的 JAX 實作。- 參數:
x1 (ArrayLike) – 形狀為
(..., M)
的陣列。x2 (ArrayLike) – 形狀為
(..., M, N)
的陣列。前導維度必須與x1
的前導維度廣播相容。
- 返回值:
一個形狀為
(..., N)
的陣列,包含批次共軛向量-矩陣乘積。- 返回類型:
另請參閱
jax.numpy.linalg.vecdot()
:批次向量乘積。jax.numpy.matvec()
:矩陣-向量乘積。jax.numpy.matmul()
:一般矩陣乘法。
範例
簡單向量-矩陣乘積
>>> x1 = jnp.array([[1, 2, 3]]) >>> x2 = jnp.array([[4, 5], ... [6, 7], ... [8, 9]]) >>> jnp.vecmat(x1, x2) Array([[40, 46]], dtype=int32)
批次向量-矩陣乘積
>>> x1 = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.vecmat(x1, x2) Array([[ 40, 46], [ 94, 109]], dtype=int32)