jax.numpy.linalg.tensorinv#
- jax.numpy.linalg.tensorinv(a, ind=2)[原始碼]#
計算陣列的張量逆矩陣。
JAX 版本的
numpy.linalg.tensorinv()
。此函式計算與相同
ind
值之tensordot()
運算的逆運算。- 參數:
a (ArrayLike) – 要反轉的陣列。必須具有
prod(a.shape[:ind]) == prod(a.shape[ind:])
ind (int) – 指定張量積中索引數量的正整數。
- 傳回:
形狀為
(*a.shape[ind:], *a.shape[:ind])
的陣列,包含a
的張量逆矩陣。- 傳回型別:
範例
>>> key = jax.random.key(1337) >>> x = jax.random.normal(key, shape=(2, 2, 4)) >>> xinv = jnp.linalg.tensorinv(x, 2) >>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2) >>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4) Array(True, dtype=bool)