jax.nn.initializers.variance_scaling#

jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)[source]#

初始化器,可使其縮放比例適應權重張量的形狀。

使用 distribution="truncated_normal"distribution="normal",樣本是從(截斷的)常態分佈中抽取的,平均值為零,標準差(如果適用,截斷後)為 \(\sqrt{\frac{scale}{n}}\),其中 n

  • 權重張量中的輸入單元數,如果 mode="fan_in"

  • 輸出單元數,如果 mode="fan_out",或

  • 輸入和輸出單元數的平均值,如果 mode="fan_avg"

此初始化器可以使用 in_axisout_axisbatch_axis 進行配置,以用於一般卷積層或密集層;未在任何這些參數中的軸會被假定為「感受野」(卷積核空間軸)。

使用 distribution="truncated_normal",樣本的絕對值會在縮放之前截斷在 2 個標準差處。

使用 distribution="uniform",樣本是從以下位置抽取的

  • 均勻間隔,如果 dtype 是實數,或

  • 均勻圓盤,如果 dtype 是複數,

平均值為零,標準差為 \(\sqrt{\frac{scale}{n}}\),其中 n 如上定義。

參數:
  • scale (RealNumeric) – 縮放因子(正浮點數)。

  • mode (Literal['fan_in'] | Literal['fan_out'] | Literal['fan_avg']) – "fan_in""fan_out""fan_avg" 其中之一。

  • distribution (Literal['truncated_normal'] | Literal['normal'] | Literal['uniform']) – 要使用的隨機分佈。"truncated_normal""normal""uniform" 其中之一。

  • in_axis (int | Sequence[int]) – 權重陣列中輸入維度的軸或軸序列。

  • out_axis (int | Sequence[int]) – 權重陣列中輸出維度的軸或軸序列。

  • batch_axis (Sequence[int]) – 權重陣列中應忽略的軸或軸序列。

  • dtype (DTypeLikeInexact) – 權重的 dtype。

返回類型:

初始化器