jax.nn 模組#

神經網路函式庫的常用函數。

激活函數#

relu

線性整流單元激活函數。

relu6

線性整流單元 6 激活函數。

sigmoid(x)

Sigmoid 激活函數。

softplus(x)

Softplus 激活函數。

sparse_plus(x)

Sparse plus 函數。

sparse_sigmoid(x)

Sparse sigmoid 激活函數。

soft_sign(x)

Soft-sign 激活函數。

silu(x)

SiLU (又稱 swish) 激活函數。

swish(x)

SiLU (又稱 swish) 激活函數。

log_sigmoid(x)

Log-sigmoid 激活函數。

leaky_relu(x[, negative_slope])

Leaky 線性整流單元激活函數。

hard_sigmoid(x)

Hard Sigmoid 激活函數。

hard_silu(x)

Hard SiLU (swish) 激活函數

hard_swish(x)

Hard SiLU (swish) 激活函數

hard_tanh(x)

Hard \(\mathrm{tanh}\) 激活函數。

elu(x[, alpha])

指數線性單元激活函數。

celu(x[, alpha])

連續可微分指數線性單元激活。

selu(x)

縮放指數線性單元激活。

gelu(x[, approximate])

高斯誤差線性單元激活函數。

glu(x[, axis])

門控線性單元激活函數。

squareplus(x[, b])

Squareplus 激活函數。

mish(x)

Mish 激活函數。

其他函數#

softmax(x[, axis, where, initial])

Softmax 函數。

log_softmax(x[, axis, where, initial])

Log-Softmax 函數。

logsumexp()

Log-sum-exp 約簡。

standardize(x[, axis, mean, variance, ...])

通過減去 mean 並除以 \(\sqrt{\mathrm{variance}}\) 來正規化陣列。

one_hot(x, num_classes, *[, dtype, axis])

One-hot 編碼給定的索引。

dot_product_attention(query, key, value[, ...])

縮放點積注意力函數。