jax.numpy.logaddexp2#

jax.numpy.logaddexp2 = <jnp.ufunc 'logaddexp2'>#

以 2 為底的輸入指數和的對數,避免溢位。

numpy.logaddexp2 的 JAX 實作。

參數:
  • x1 – 輸入陣列或純量。

  • x2 – 輸入陣列或純量。x1x2 應具有相同形狀或廣播相容性。

  • args (ArrayLike)

  • out (None)

  • where (None)

回傳:

一個包含結果的陣列,\(log_2(2^{x1}+2^{x2})\),逐元素運算。

回傳型別:

Any

另請參閱

範例

>>> x1 = jnp.array([[3, -1, 4],
...                 [8, 5, -2]])
>>> x2 = jnp.array([2, 3, -5])
>>> result1 = jnp.logaddexp2(x1, x2)
>>> result2 = jnp.log2(jnp.exp2(x1) + jnp.exp2(x2))
>>> jnp.allclose(result1, result2)
Array(True, dtype=bool)