jax.numpy.cumprod#
- jax.numpy.cumprod(a, axis=None, dtype=None, out=None)[原始碼]#
沿著軸的元素累積乘積。
numpy.cumprod()
的 JAX 實作。- 參數:
a (ArrayLike) – 要累積的 N 維陣列。
axis (int | None) – 要沿著累積的整數軸。如果為 None (預設值),則陣列將被展平並沿著展平的軸累積。
dtype (DTypeLike | None) – 可選地指定輸出的 dtype。如果未指定,則輸出 dtype 將與輸入 dtype 相符。
out (None) – JAX 未使用
- 返回:
一個包含沿給定軸的累積乘積的陣列。
- 返回類型:
參見
jax.numpy.multiply.accumulate()
:透過 ufunc 方法的累積乘積。jax.numpy.nancumprod()
:忽略 NaN 值的累積乘積。jax.numpy.prod()
:沿軸的乘積
範例
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.cumprod(x) # flattened cumulative product Array([ 1, 2, 6, 24, 120, 720], dtype=int32) >>> jnp.cumprod(x, axis=1) # cumulative product along axis 1 Array([[ 1, 2, 6], [ 4, 20, 120]], dtype=int32)