jax.nn.initializers 模組#

常見的神經網路層初始化器,與 Keras 和 Sonnet 中使用的定義一致。

初始化器#

此模組提供常見的神經網路層初始化器,與 Keras 和 Sonnet 中使用的定義一致。

初始化器是一個函數,它接受三個參數:(key, shape, dtype) 並返回一個具有維度 shape 和資料類型 dtype 的陣列。參數 key 是一個 PRNG 金鑰(例如來自 jax.random.key()),用於產生隨機數字以初始化陣列。

constant(value[, dtype])

建立一個初始化器,它返回充滿常數 value 的陣列。

delta_orthogonal([scale, column_axis, dtype])

為 delta 正交核建立初始化器。

glorot_normal([in_axis, out_axis, ...])

建立 Glorot 正規化初始化器(又名 Xavier 正規化初始化器)。

glorot_uniform([in_axis, out_axis, ...])

建立 Glorot 均勻初始化器(又名 Xavier 均勻初始化器)。

he_normal([in_axis, out_axis, batch_axis, dtype])

建立 He 正規化初始化器(又名 Kaiming 正規化初始化器)。

he_uniform([in_axis, out_axis, batch_axis, ...])

建立 He 均勻初始化器(又名 Kaiming 均勻初始化器)。

lecun_normal([in_axis, out_axis, ...])

建立 Lecun 正規化初始化器。

lecun_uniform([in_axis, out_axis, ...])

建立 Lecun 均勻初始化器。

normal([stddev, dtype])

建立一個初始化器,它返回實數常態分佈的隨機陣列。

ones(key, shape[, dtype])

一個初始化器,它返回充滿 1 的常數陣列。

orthogonal([scale, column_axis, dtype])

建立一個初始化器,它返回均勻分佈的正交矩陣。

truncated_normal([stddev, dtype, lower, upper])

建立一個初始化器,它返回截斷常態分佈的隨機陣列。

uniform([scale, dtype])

建立一個初始化器,它返回實數均勻分佈的隨機陣列。

variance_scaling(scale, mode, distribution)

初始化器,使其尺度適應權重張量的形狀。

zeros(key, shape[, dtype])

一個初始化器,它返回充滿零的常數陣列。