jax.numpy.linalg.tensordot#
- jax.numpy.linalg.tensordot(x1, x2, /, *, axes=2, precision=None, preferred_element_type=None)[原始碼]#
計算兩個 N 維陣列的張量點積。
JAX 實作的
numpy.linalg.tensordot()
。- 參數:
x1 (ArrayLike) – N 維陣列
x2 (ArrayLike) – M 維陣列
axes (int | tuple[Sequence[int], Sequence[int]]) – 整數或整數序列的元組。如果是一個整數 k,則對
x1
的最後 k 個軸和x2
的前 k 個軸依序求和。如果是一個元組,則axes[0]
指定x1
的軸,而axes[1]
指定x2
的軸。precision (PrecisionLike | None) – 可以是
None
(預設值),表示後端的預設精確度;Precision
列舉值 (Precision.DEFAULT
、Precision.HIGH
或Precision.HIGHEST
);或是兩個這類值的元組,表示x1
和x2
的精確度。preferred_element_type (DTypeLike | None | None) – 可以是
None
(預設值),表示輸入類型的預設累積類型;或是一個資料類型,表示將結果累積到該資料類型並傳回具有該資料類型的結果。
- 傳回值:
包含輸入張量點積的陣列
- 傳回類型:
另請參閱
jax.numpy.tensordot()
:在jax.numpy
命名空間中等效的 API。jax.numpy.einsum()
:用於更通用張量收縮的 NumPy API。jax.lax.dot_general()
:用於更通用張量收縮的 XLA API。
範例
>>> x1 = jnp.arange(24.).reshape(2, 3, 4) >>> x2 = jnp.ones((3, 4, 5)) >>> jnp.linalg.tensordot(x1, x2) Array([[ 66., 66., 66., 66., 66.], [210., 210., 210., 210., 210.]], dtype=float32)
當將軸指定為明確序列時的等效結果
>>> jnp.linalg.tensordot(x1, x2, axes=([1, 2], [0, 1])) Array([[ 66., 66., 66., 66., 66.], [210., 210., 210., 210., 210.]], dtype=float32)
透過
einsum()
的等效結果>>> jnp.einsum('ijk,jkm->im', x1, x2) Array([[ 66., 66., 66., 66., 66.], [210., 210., 210., 210., 210.]], dtype=float32)
對於二維輸入,設定
axes=1
等同於矩陣乘法>>> x1 = jnp.array([[1, 2], ... [3, 4]]) >>> x2 = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.linalg.tensordot(x1, x2, axes=1) Array([[ 9, 12, 15], [19, 26, 33]], dtype=int32) >>> x1 @ x2 Array([[ 9, 12, 15], [19, 26, 33]], dtype=int32)
對於一維輸入,設定
axes=0
等同於jax.numpy.linalg.outer()
>>> x1 = jnp.array([1, 2]) >>> x2 = jnp.array([1, 2, 3]) >>> jnp.linalg.tensordot(x1, x2, axes=0) Array([[1, 2, 3], [2, 4, 6]], dtype=int32) >>> jnp.linalg.outer(x1, x2) Array([[1, 2, 3], [2, 4, 6]], dtype=int32)