jax.lax.associative_scan#
- jax.lax.associative_scan(fn, elems, reverse=False, axis=0)[source]#
使用結合性二元運算平行執行掃描。
關於結合性掃描的介紹,請參閱 [BLE1990]。
- 參數:
fn (Callable) –
實作結合性二元運算的 Python 可呼叫物件,簽名為
r = fn(a, b)
。函數 fn 必須是結合性的,即它必須滿足方程式fn(a, fn(b, c)) == fn(fn(a, b), c)
。輸入和結果是(可能是巢狀 Python 樹狀結構的)陣列,與
elems
相符。每個陣列都有一個維度來代替axis
維度。fn 應以元素方式套用在axis
維度上(例如,透過使用jax.vmap()
於元素方式函數上)。結果
r
具有與兩個輸入a
和b
相同的形狀(和結構)。elems – (可能是巢狀 Python 樹狀結構的)陣列,每個陣列都有一個大小為
num_elems
的axis
維度。reverse (bool) – 布林值,指出是否應針對
axis
維度反轉掃描。axis (int) – 識別應在其上執行掃描之軸的整數。
- 傳回:
與
elems
具有相同形狀和結構的(可能是巢狀 Python 樹狀結構的)陣列,其中axis
的第k
個元素是遞迴套用fn
以組合elems
沿axis
的前k
個元素的結果。例如,給定elems = [a, b, c, ...]
,結果將為[a, fn(a, b), fn(fn(a, b), c), ...]
。
範例 1:數字陣列的部分和
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4)) Array([0, 1, 3, 6], dtype=int32)
範例 2:矩陣陣列的部分乘積
>>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2)) >>> partial_prods = lax.associative_scan(jnp.matmul, mats) >>> partial_prods.shape (4, 2, 2)
範例 3:數字陣列的反向部分和
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True) Array([6, 6, 5, 3], dtype=int32)
[BLE1990]Blelloch, Guy E. 1990. “Prefix Sums and Their Applications.”, Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University.