jax.numpy.einsum_path#

jax.numpy.einsum_path(subscripts, /, *operands, optimize='auto')[原始碼]#

評估最佳收縮路徑,而不評估 einsum。

numpy.einsum_path() 的 JAX 實作。此函數調用 opt_einsum 套件,並使用其最佳化例程。

參數:
  • subscripts – 包含以逗號分隔的軸名稱的字串。

  • *operands – 對應於 subscripts 的一個或多個陣列的序列。

  • optimize (bool | str | list[tuple[int, ...]]) – 指定如何最佳化計算順序。在 JAX 中,這預設為 "auto"。其他選項包括 True (與 "optimize" 相同)、False (未最佳化),或 opt_einsum 支援的任何字串,其中包括 "optimize""greedy""eager" 等。

返回:

一個元組,包含可能傳遞給 einsum() 的路徑,以及表示此最佳路徑的可列印物件。

返回類型:

tuple[list[tuple[int, …]], Any]

範例

>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.randint(key1, minval=-5, maxval=5, shape=(2, 3))
>>> y = jax.random.randint(key2, minval=-5, maxval=5, shape=(3, 100))
>>> z = jax.random.randint(key3, minval=-5, maxval=5, shape=(100, 5))
>>> path, path_info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="optimal")
>>> print(path)
[(1, 2), (0, 1)]
>>> print(path_info)
      Complete contraction:  ij,jk,kl->il
            Naive scaling:  4
        Optimized scaling:  3
          Naive FLOP count:  9.000e+3
      Optimized FLOP count:  3.060e+3
      Theoretical speedup:  2.941e+0
      Largest intermediate:  1.500e+1 elements
    --------------------------------------------------------------------------------
    scaling        BLAS                current                             remaining
    --------------------------------------------------------------------------------
      3           GEMM              kl,jk->lj                             ij,lj->il
      3           GEMM              lj,ij->il                                il->il

einsum() 中使用計算路徑

>>> jnp.einsum("ij,jk,kl", x, y, z, optimize=path)
Array([[-754,  324, -142,   82,   50],
       [ 408,  -50,   87,  -29,    7]], dtype=int32)