jax.ops.segment_prod#

jax.ops.segment_prod(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)[原始碼]#

計算陣列分段內的乘積。

類似於 TensorFlow 的 segment_prod

參數:
  • data (ArrayLike) – 包含要縮減值的陣列。

  • segment_ids (ArrayLike) – 一個整數 dtype 陣列,指示要縮減的 data 段(沿其前導軸)。值可以重複,並且不需要排序。範圍 [0, num_segments) 之外的值將被捨棄,且不計入結果。

  • num_segments (int | None | None) – 選填,一個非負整數值,指示分段的數量。預設設定為支援 segment_ids 中所有索引的最小分段數,計算方式為 max(segment_ids) + 1。由於 num_segments 決定了輸出的尺寸,因此必須提供靜態值才能在 JIT 編譯的函數中使用 segment_prod

  • indices_are_sorted (bool) – segment_ids 是否已知已排序。

  • unique_indices (bool) – segment_ids 是否已知沒有重複項。

  • bucket_size (int | None | None) – 將索引分組到的 bucket 大小。segment_prod 在每個 bucket 上分別執行,以提高加法的數值穩定性。預設值 None 表示不進行 bucketing。

  • mode (lax.GatherScatterMode | None | None) – jax.lax.GatherScatterMode 值,描述應如何處理超出範圍的索引。預設情況下,範圍 [0, num_segments) 之外的值將被捨棄,且不計入總和。

傳回:

形狀為 (num_segments,) + data.shape[1:] 的陣列,表示分段乘積。

傳回類型:

Array

範例

簡單的 1D 分段乘積

>>> data = jnp.arange(6)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_prod(data, segment_ids)
Array([ 0,  6, 20], dtype=int32)

使用 JIT 需要靜態 num_segments

>>> from jax import jit
>>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3)
Array([ 0,  6, 20], dtype=int32)