jax.nn.logsumexp#
- jax.nn.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: Literal[False] = False, where: ArrayLike | None = None) Array [原始碼]#
- jax.nn.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, *, return_sign: Literal[True], where: ArrayLike | None = None) tuple[Array, Array]
- jax.nn.logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) Array | tuple[Array, Array]
Log-sum-exp 約簡。
scipy.special.logsumexp()
的 JAX 實作。\[\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})\]其中 \(j\) 索引範圍涵蓋一個或多個要約簡的維度。
- 參數:
a – 輸入陣列
axis – 要在其中約簡的軸或多個軸。可以是
None
、整數或整數元組。b – \(\mathrm{exp}(a)\) 的縮放係數。必須可廣播到 a 的形狀。
keepdims – 如果為
True
,則約簡的軸會保留在輸出中,作為大小為 1 的維度。return_sign – 如果為
True
,則輸出將為(result, sign)
配對,其中sign
是總和的符號,而result
包含其絕對值的對數。如果為False
,則僅傳回result
,如果總和為負數,則它將包含 NaN 值。where – 要包含在約簡中的元素。
- 傳回:
陣列
result
或陣列配對(result, sign)
,取決於return_sign
引數的值。