jax.lax.associative_scan#

jax.lax.associative_scan(fn, elems, reverse=False, axis=0)[source]#

使用結合性二元運算平行執行掃描。

關於結合性掃描的介紹,請參閱 [BLE1990]

參數:
  • fn (Callable) –

    實作結合性二元運算的 Python 可呼叫物件,簽名為 r = fn(a, b)。函數 fn 必須是結合性的,即它必須滿足方程式 fn(a, fn(b, c)) == fn(fn(a, b), c)

    輸入和結果是(可能是巢狀 Python 樹狀結構的)陣列,與 elems 相符。每個陣列都有一個維度來代替 axis 維度。fn 應以元素方式套用在 axis 維度上(例如,透過使用 jax.vmap() 於元素方式函數上)。

    結果 r 具有與兩個輸入 ab 相同的形狀(和結構)。

  • elems – (可能是巢狀 Python 樹狀結構的)陣列,每個陣列都有一個大小為 num_elemsaxis 維度。

  • reverse (bool) – 布林值,指出是否應針對 axis 維度反轉掃描。

  • axis (int) – 識別應在其上執行掃描之軸的整數。

傳回:

elems 具有相同形狀和結構的(可能是巢狀 Python 樹狀結構的)陣列,其中 axis 的第 k 個元素是遞迴套用 fn 以組合 elems 沿 axis 的前 k 個元素的結果。例如,給定 elems = [a, b, c, ...],結果將為 [a, fn(a, b), fn(fn(a, b), c), ...]

範例 1:數字陣列的部分和

>>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
Array([0, 1, 3, 6], dtype=int32)

範例 2:矩陣陣列的部分乘積

>>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2))
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
>>> partial_prods.shape
(4, 2, 2)

範例 3:數字陣列的反向部分和

>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
Array([6, 6, 5, 3], dtype=int32)
[BLE1990]

Blelloch, Guy E. 1990. “Prefix Sums and Their Applications.”, Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University.