jax.scipy.special.logsumexp#

jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: Literal[False] = False, where: ArrayLike | None = None) Array[source]#
jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, *, return_sign: Literal[True], where: ArrayLike | None = None) tuple[Array, Array]
jax.scipy.special.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) Array | tuple[Array, Array]

Log-sum-exp 歸約。

scipy.special.logsumexp() 的 JAX 實作。

\[\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})\]

其中 j 索引範圍涵蓋一個或多個要歸約的維度。

參數:
  • a – 輸入陣列

  • axis – 要歸約的軸。可以是 None、整數或整數元組。

  • b\(\mathrm{exp}(a)\) 的縮放因子。必須可廣播到 a 的形狀。

  • keepdims – 如果為 True,則歸約的軸會保留在輸出中,作為大小為 1 的維度。

  • return_sign – 如果為 True,則輸出將為 (result, sign) 對,其中 sign 是總和的符號,而 result 包含其絕對值的對數。如果為 False,則僅傳回 result,如果總和為負數,則它將包含 NaN 值。

  • where – 要包含在歸約中的元素。

傳回:

根據 return_sign 參數的值,傳回陣列 result 或陣列對 (result, sign)