jax.numpy.cumsum#
- jax.numpy.cumsum(a, axis=None, dtype=None, out=None)[原始碼]#
沿著軸的元素累積總和。
JAX 版本的
numpy.cumsum()
實作。- 參數:
a (ArrayLike) – 要累積的 N 維陣列。
axis (int | None) – 沿著其累積的整數軸。如果為 None (預設值),則陣列將被展平並沿著展平的軸累積。
dtype (DTypeLike | None) – 可選地指定輸出的 dtype。如果未指定,則輸出 dtype 將與輸入 dtype 相符。
out (None) – JAX 未使用
- 返回:
一個包含沿給定軸的累積總和的陣列。
- 返回類型:
參見
jax.numpy.cumulative_sum()
:透過陣列 API 標準的累積總和。jax.numpy.add.accumulate()
:透過 ufunc 方法的累積總和。jax.numpy.nancumsum()
:忽略 NaN 值的累積總和。jax.numpy.sum()
:沿軸的總和
範例
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.cumsum(x) # flattened cumulative sum Array([ 1, 3, 6, 10, 15, 21], dtype=int32) >>> jnp.cumsum(x, axis=1) # cumulative sum along axis 1 Array([[ 1, 3, 6], [ 4, 9, 15]], dtype=int32)