jax.ops.segment_min#
- jax.ops.segment_min(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)[原始碼]#
計算陣列分段內的最小值。
類似於 TensorFlow 的 segment_min
- 參數:
data (ArrayLike) – 要縮減值的陣列。
segment_ids (ArrayLike) – 一個整數 dtype 的陣列,指示要縮減的 data 段(沿其前導軸)。值可以重複,並且不需要排序。超出範圍 [0, num_segments) 的值將被捨棄,並且不影響結果。
num_segments (int | None | None) – 選項,一個非負整數值,指示分段的數量。預設值設定為支援
segment_ids
中所有索引的最小分段數,計算方式為max(segment_ids) + 1
。由於 num_segments 決定了輸出的尺寸,因此必須提供靜態值才能在 JIT 編譯的函數中使用segment_min
。indices_are_sorted (bool) –
segment_ids
是否已知為已排序。unique_indices (bool) – segment_ids 是否已知為沒有重複值。
bucket_size (int | None | None) – 將索引分組到 bucket 中的大小。
segment_min
會在每個 bucket 上分別執行。預設值None
表示不進行 bucketing。mode (lax.GatherScatterMode | None | None) –
jax.lax.GatherScatterMode
值,描述應如何處理超出範圍的索引。依預設,超出範圍 [0, num_segments) 的值將被捨棄,並且不影響總和。
- 返回:
一個形狀為
(num_segments,) + data.shape[1:]
的陣列,表示分段最小值。- 返回類型:
範例
簡單的 1D 分段最小值
>>> data = jnp.arange(6) >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2]) >>> segment_min(data, segment_ids) Array([0, 2, 4], dtype=int32)
使用 JIT 需要靜態 num_segments
>>> from jax import jit >>> jit(segment_min, static_argnums=2)(data, segment_ids, 3) Array([0, 2, 4], dtype=int32)