jax.numpy.inner#

jax.numpy.inner(a, b, *, precision=None, preferred_element_type=None)[原始碼]#

計算兩個陣列的內積。

JAX 實作的 numpy.inner()

jax.numpy.matmul()jax.numpy.dot() 不同,這始終沿著每個輸入的最後一個維度執行收縮。

參數:
  • a (ArrayLike) – 形狀為 (..., N) 的陣列

  • b (ArrayLike) – 形狀為 (..., N) 的陣列

  • precision (PrecisionLike) – None (預設值),表示後端的預設精度,Precision 列舉值 (Precision.DEFAULTPrecision.HIGHPrecision.HIGHEST) 或兩個此類值的元組,表示 ab 的精度。

  • preferred_element_type (DType | None) – None (預設值),表示輸入類型的預設累積類型,或資料類型,表示將結果累積到該資料類型並傳回結果。

傳回:

形狀為 (*a.shape[:-1], *b.shape[:-1]) 的陣列,包含輸入的批次向量積。

傳回類型:

Array

另請參閱

範例

對於 1D 輸入,這會實作標準(非共軛)向量乘法

>>> a = jnp.array([1j, 3j, 4j])
>>> b = jnp.array([4., 2., 5.])
>>> jnp.inner(a, b)
Array(0.+30.j, dtype=complex64)

對於多維輸入,批次維度會堆疊而不是廣播

>>> a = jnp.ones((2, 3))
>>> b = jnp.ones((5, 3))
>>> jnp.inner(a, b).shape
(2, 5)