jax.numpy.logaddexp#
- jax.numpy.logaddexp = <jnp.ufunc 'logaddexp'>#
計算
log(exp(x1) + exp(x2))
,避免溢位。JAX 版本的
numpy.logaddexp
- 參數:
x1 – 輸入陣列
x2 – 輸入陣列
args (ArrayLike)
out (None)
where (None)
- 傳回值:
包含結果的陣列。
- 傳回類型:
Any
範例
>>> x1 = jnp.array([1, 2, 3]) >>> x2 = jnp.array([4, 5, 6]) >>> result1 = jnp.logaddexp(x1, x2) >>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2)) >>> print(jnp.allclose(result1, result2)) True