jax.numpy.vecdot#

jax.numpy.vecdot(x1, x2, /, *, axis=-1, precision=None, preferred_element_type=None)[原始碼]#

執行兩個批次向量的共軛乘法。

numpy.vecdot() 的 JAX 實作。

參數::
  • a – 左側陣列。

  • b – 右側陣列。 b[axis] 的大小必須與 a[axis] 的大小相符,且其餘維度必須是可廣播相容的。

  • axis (int) – 計算點積的軸 (預設:-1)

  • precision (PrecisionLike | None) – None (預設),表示後端的預設精度、Precision 列舉值 (Precision.DEFAULTPrecision.HIGHPrecision.HIGHEST) 或兩個此類值的元組,表示 ab 的精度。

  • preferred_element_type (DTypeLike | None | None) – None (預設),表示輸入型別的預設累加型別,或是一種資料型別,表示將結果累加至該資料型別並傳回具有該資料型別的結果。

  • x1 (ArrayLike)

  • x2 (ArrayLike)

返回::

陣列,包含 ab 沿著 axis 的共軛點積。非收縮維度會廣播在一起。

返回型別::

Array

另請參閱

範例

兩個一維陣列的向量共軛點積

>>> a = jnp.array([1j, 2j, 3j])
>>> b = jnp.array([4., 5., 6.])
>>> jnp.linalg.vecdot(a, b)
Array(0.-32.j, dtype=complex64)

兩個二維陣列的批次向量點積

>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[2, 3, 4]])
>>> jnp.linalg.vecdot(a, b, axis=-1)
Array([20, 47], dtype=int32)