jax.numpy.vecmat#

jax.numpy.vecmat(x1, x2, /)[source]#

批次共軛向量-矩陣乘積。

numpy.vecmat() 的 JAX 實作。

參數:
  • x1 (ArrayLike) – 形狀為 (..., M) 的陣列。

  • x2 (ArrayLike) – 形狀為 (..., M, N) 的陣列。前導維度必須與 x1 的前導維度廣播相容。

返回值:

一個形狀為 (..., N) 的陣列,包含批次共軛向量-矩陣乘積。

返回類型:

陣列

另請參閱

範例

簡單向量-矩陣乘積

>>> x1 = jnp.array([[1, 2, 3]])
>>> x2 = jnp.array([[4, 5],
...                 [6, 7],
...                 [8, 9]])
>>> jnp.vecmat(x1, x2)
Array([[40, 46]], dtype=int32)

批次向量-矩陣乘積

>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.vecmat(x1, x2)
Array([[ 40,  46],
       [ 94, 109]], dtype=int32)