jax.numpy.logaddexp#

jax.numpy.logaddexp = <jnp.ufunc 'logaddexp'>#

計算 log(exp(x1) + exp(x2)),避免溢位。

JAX 版本的 numpy.logaddexp

參數:
  • x1 – 輸入陣列

  • x2 – 輸入陣列

  • args (ArrayLike)

  • out (None)

  • where (None)

傳回值:

包含結果的陣列。

傳回類型:

Any

範例

>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> result1 = jnp.logaddexp(x1, x2)
>>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2))
>>> print(jnp.allclose(result1, result2))
True