jax.numpy.inner#
- jax.numpy.inner(a, b, *, precision=None, preferred_element_type=None)[原始碼]#
計算兩個陣列的內積。
JAX 實作的
numpy.inner()
。與
jax.numpy.matmul()
或jax.numpy.dot()
不同,這始終沿著每個輸入的最後一個維度執行收縮。- 參數:
a (ArrayLike) – 形狀為
(..., N)
的陣列b (ArrayLike) – 形狀為
(..., N)
的陣列precision (PrecisionLike) –
None
(預設值),表示後端的預設精度,Precision
列舉值 (Precision.DEFAULT
、Precision.HIGH
或Precision.HIGHEST
) 或兩個此類值的元組,表示a
和b
的精度。preferred_element_type (DType | None) –
None
(預設值),表示輸入類型的預設累積類型,或資料類型,表示將結果累積到該資料類型並傳回結果。
- 傳回:
形狀為
(*a.shape[:-1], *b.shape[:-1])
的陣列,包含輸入的批次向量積。- 傳回類型:
另請參閱
jax.numpy.vecdot()
:沿指定軸的共軛乘法。jax.numpy.tensordot()
:一般張量乘法。jax.numpy.matmul()
:一般批次矩陣和向量乘法。
範例
對於 1D 輸入,這會實作標準(非共軛)向量乘法
>>> a = jnp.array([1j, 3j, 4j]) >>> b = jnp.array([4., 2., 5.]) >>> jnp.inner(a, b) Array(0.+30.j, dtype=complex64)
對於多維輸入,批次維度會堆疊而不是廣播
>>> a = jnp.ones((2, 3)) >>> b = jnp.ones((5, 3)) >>> jnp.inner(a, b).shape (2, 5)