jax.numpy.linalg.multi_dot#

jax.numpy.linalg.multi_dot(arrays, *, precision=None)[來源]#

有效率地計算陣列序列之間的矩陣乘積。

JAX 版本的 numpy.linalg.multi_dot()

JAX 內部使用 opt_einsum 函式庫來計算最有效率的運算順序。

參數:
  • arrays (Sequence[ArrayLike]) – 陣列序列。除了第一個和最後一個可能是一維之外,所有陣列都必須是二維的。

  • precision (PrecisionLike | None) – 可以是 None (預設),表示後端的預設精確度;或是 Precision 列舉值 (Precision.DEFAULTPrecision.HIGHPrecision.HIGHEST)。

回傳值:

一個陣列,代表等效於 reduce(jnp.matmul, arrays) 的結果,但以最佳順序評估。

回傳型別:

Array

此函式的存在是因為計算矩陣乘法運算序列的成本,可能會因為運算評估的順序而有極大差異。對於單個矩陣乘法,計算矩陣乘積所需的浮點運算 (flops) 數量可以這樣近似

>>> def approx_flops(x, y):
...   # for 2D x and y, with x.shape[1] == y.shape[0]
...   return 2 * x.shape[0] * x.shape[1] * y.shape[1]

假設我們有三個矩陣,想要依序相乘

>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.normal(key1, shape=(200, 5))
>>> y = jax.random.normal(key2, shape=(5, 100))
>>> z = jax.random.normal(key3, shape=(100, 10))

由於矩陣乘積的結合律,我們可以使用兩種順序來評估乘積 x @ y @ z,並且兩者都會產生等效的輸出,直到浮點精度

>>> result1 = (x @ y) @ z
>>> result2 = x @ (y @ z)
>>> jnp.allclose(result1, result2, atol=1E-4)
Array(True, dtype=bool)

但是這些的計算成本差異很大

>>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z))
(x @ y) @ z flops: 600000
>>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z))
x @ (y @ z) flops: 30000

第二種方法在估計的 flops 方面效率大約高出 20 倍!

multi_dot 是一個會自動為此類問題選擇最快計算路徑的函式

>>> result3 = jnp.linalg.multi_dot([x, y, z])
>>> jnp.allclose(result1, result3, atol=1E-4)
Array(True, dtype=bool)

我們可以使用 JAX 的 預先降低和編譯 工具來估計每種方法的總 flops,並確認 multi_dot 正在選擇更有效率的選項

>>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops']
600000.0
>>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops']
30000.0
>>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops']
30000.0