jax.numpy.linalg.tensorsolve#
- jax.numpy.linalg.tensorsolve(a, b, axes=None)[原始碼]#
求解張量方程式 a x = b 中的 x。
JAX 版本的
numpy.linalg.tensorsolve()
。- 參數:
- 返回:
陣列 x,使得在重新排序
a
的軸後,tensordot(a, x, x.ndim)
等同於b
。- 返回類型:
範例
>>> key1, key2 = jax.random.split(jax.random.key(8675309)) >>> a = jax.random.normal(key1, shape=(2, 2, 4)) >>> b = jax.random.normal(key2, shape=(2, 2)) >>> x = jnp.linalg.tensorsolve(a, b) >>> x.shape (4,)
現在展示如何使用
tensordot()
,利用x
重建b
>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim) >>> jnp.allclose(b, b_reconstructed) Array(True, dtype=bool)