jax.lax.batch_matmul#

jax.lax.batch_matmul(lhs, rhs, precision=None)[source]#

批次矩陣乘法。

參數:
  • lhs (Array)

  • rhs (Array)

  • precision (PrecisionLike | None)

回傳型別:

Array