jax.nn.softmax#

jax.nn.softmax(x, axis=-1, where=None, initial=_UNSPECIFIED)[source]#

Softmax 函數。

計算將元素重新縮放至 \([0, 1]\) 範圍的函數,使得沿著 axis 的元素總和為 \(1\)

\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
參數:
  • x (*ArrayLike*) – 輸入陣列

  • axis (*int** | **tuple*[**int**, *...*]* | *None*) – 應計算 softmax 的軸或軸。在這些維度上加總的 softmax 輸出應總和為 \(1\)。可以是整數或整數元組。

  • where (*ArrayLike* | *None* | *None*) – 要包含在 softmax 中的元素。

  • initial (*Unspecified*)

傳回:

一個陣列。

傳回型別:

Array

注意

如果任何輸入值為 +inf,結果將會是全部 NaN:這反映了 inf / inf 在浮點數運算中未明確定義的事實。

另請參閱

log_softmax()