jax.numpy.linalg.tensorinv#

jax.numpy.linalg.tensorinv(a, ind=2)[原始碼]#

計算陣列的張量逆矩陣。

JAX 版本的 numpy.linalg.tensorinv()

此函式計算與相同 ind 值之 tensordot() 運算的逆運算。

參數:
  • a (ArrayLike) – 要反轉的陣列。必須具有 prod(a.shape[:ind]) == prod(a.shape[ind:])

  • ind (int) – 指定張量積中索引數量的正整數。

傳回:

形狀為 (*a.shape[ind:], *a.shape[:ind]) 的陣列,包含 a 的張量逆矩陣。

傳回型別:

Array

範例

>>> key = jax.random.key(1337)
>>> x = jax.random.normal(key, shape=(2, 2, 4))
>>> xinv = jnp.linalg.tensorinv(x, 2)
>>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2)
>>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4)
Array(True, dtype=bool)