jax.ops 模組#

函式 jax.ops.index_updatejax.ops.index_add 等在 JAX 0.2.22 中已棄用,現已移除。請改用 JAX 陣列上的 jax.numpy.ndarray.at 屬性。

區段縮減運算子#

segment_max(data, segment_ids[, ...])

計算陣列區段內的最大值。

segment_min(data, segment_ids[, ...])

計算陣列區段內的最小值。

segment_prod(data, segment_ids[, ...])

計算陣列區段內的乘積。

segment_sum(data, segment_ids[, ...])

計算陣列區段內的總和。