jax.numpy.linalg.tensorsolve#

jax.numpy.linalg.tensorsolve(a, b, axes=None)[原始碼]#

求解張量方程式 a x = b 中的 x。

JAX 版本的 numpy.linalg.tensorsolve()

參數:
  • a (ArrayLike) – 輸入陣列。透過 axes 重新排序後(見下方),形狀必須為 (*b.shape, *x.shape)

  • b (ArrayLike) – 右側陣列。

  • axes (tuple[int, ...] | None | None) – 可選元組,指定 a 中應移至末尾的軸

返回:

陣列 x,使得在重新排序 a 的軸後,tensordot(a, x, x.ndim) 等同於 b

返回類型:

Array

範例

>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4))
>>> b = jax.random.normal(key2, shape=(2, 2))
>>> x = jnp.linalg.tensorsolve(a, b)
>>> x.shape
(4,)

現在展示如何使用 tensordot(),利用 x 重建 b

>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
>>> jnp.allclose(b, b_reconstructed)
Array(True, dtype=bool)