jax.experimental.sparse.bcoo_dot_general_sampled#

jax.experimental.sparse.bcoo_dot_general_sampled(A, B, indices, *, dimension_numbers)[source]#

在給定的稀疏索引處計算輸出的收縮運算。

參數:
  • lhs – 一個 ndarray。

  • rhs – 一個 ndarray。

  • indices (Array) – BCOO 索引。

  • dimension_numbers (DotDimensionNumbers) – 形式為 ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)) 的元組。

  • A (Array)

  • B (Array)

回傳:

BCOO 資料,一個包含結果的 ndarray。

回傳型別:

Array