jax.numpy.nancumsum#
- jax.numpy.nancumsum(a, axis=None, dtype=None, out=None)[原始碼]#
沿著軸向的元素累加總和,忽略 NaN 值。
JAX 實作的
numpy.nancumsum()
。- 參數:
a (ArrayLike) – 要累加的 N 維陣列。
axis (int | None) – 沿著哪個軸向累加的整數軸。如果為 None (預設值),則陣列將被展平,並沿著展平的軸向累加。
dtype (DTypeLike | None) – 可選地指定輸出的 dtype。如果未指定,則輸出 dtype 將與輸入 dtype 相符。
out (None) – JAX 未使用
- 回傳:
一個包含沿著給定軸向累加總和的陣列。
- 回傳類型:
參見
jax.numpy.cumsum()
:不忽略 NaN 值的累加總和。jax.numpy.cumulative_sum()
:透過陣列 API 標準的累加總和。jax.numpy.add.accumulate()
:透過 ufunc 方法的累加總和。jax.numpy.sum()
:沿軸向的總和
範例
>>> x = jnp.array([[1., 2., jnp.nan], ... [4., jnp.nan, 6.]])
標準累加總和將傳播 NaN 值
>>> jnp.cumsum(x) Array([ 1., 3., nan, nan, nan, nan], dtype=float32)
nancumsum()
將忽略 NaN 值,有效地將其替換為零>>> jnp.nancumsum(x) Array([ 1., 3., 3., 7., 7., 13.], dtype=float32)
沿軸 1 的累加總和
>>> jnp.nancumsum(x, axis=1) Array([[ 1., 3., 3.], [ 4., 4., 10.]], dtype=float32)