jax.scipy.special.softmax#

jax.scipy.special.softmax(x, /, *, axis=None)[原始碼]#

Softmax 函數。

JAX 實現的 scipy.special.softmax()

計算將元素重新縮放到範圍 \([0, 1]\) 的函數,使得沿 axis 的元素總和為 \(1\)

\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
參數:
  • x (ArrayLike) – 輸入陣列

  • axis (int | tuple[int, ...] | None | None) – 應計算 softmax 的軸或軸。在這些維度上加總的 softmax 輸出應總和為 \(1\)

傳回:

x 形狀相同的陣列。

傳回型別:

陣列

注意

如果任何輸入值為 +inf,則結果將全為 NaN:這反映了 inf / inf 在浮點數數學的上下文中未明確定義的事實。

另請參閱

log_softmax()