jax.example_libraries.stax 模組#

Stax 是一個小巧但靈活的神經網路規格庫,從頭開始建構。

您可能不打算匯入此模組! Stax 僅作為範例庫。JAX 有許多其他功能更完整的神經網路庫,包括 Google 的 Flax 和 DeepMind 的 Haiku

jax.example_libraries.stax.AvgPool(window_shape, strides=None, padding='VALID', spec=None)[source]#

用於池化層的層建構函式。

jax.example_libraries.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)[source]#

用於批次正規化層的層建構函式。

jax.example_libraries.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#

用於一般卷積層的層建構函式。

jax.example_libraries.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#

用於一般轉置卷積層的層建構函式。

jax.example_libraries.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#

用於一般轉置卷積層的層建構函式。

jax.example_libraries.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)[source]#

用於密集(完全連接)層的層建構子函式。

jax.example_libraries.stax.Dropout(rate, mode='train')[source]#

用於具有給定比率的 dropout 層的層建構函式。

jax.example_libraries.stax.FanInConcat(axis=-1)[source]#

用於扇入串聯層的層建構函式。

jax.example_libraries.stax.FanOut(num)[source]#

用於扇出層的層建構函式。

jax.example_libraries.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[source]#

用於一般卷積層的層建構函式。

jax.example_libraries.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[source]#

用於一般轉置卷積層的層建構函式。

jax.example_libraries.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)[source]#

用於池化層的層建構函式。

jax.example_libraries.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)[source]#

用於池化層的層建構函式。

jax.example_libraries.stax.elementwise(fun, **fun_kwargs)[source]#

在其輸入上逐元素應用純量函數的層。

jax.example_libraries.stax.parallel(*layers)[source]#

用於平行組合層的組合器。

由此組合器產生的層通常與 FanOut 和 FanInSum 層一起使用。

參數:

*layers – 一系列層,每個層都是一個 (init_fun, apply_fun) 對。

返回:

一個新的層,表示一個 (init_fun, apply_fun) 對,代表給定層序列的平行組合。特別是,返回的層接受一系列輸入,並返回與參數 layers 長度相同的輸出序列。

jax.example_libraries.stax.serial(*layers)[source]#

用於串列組合層的組合器。

參數:

*layers – 一系列層,每個層都是一個 (init_fun, apply_fun) 對。

返回:

一個新的層,表示一個 (init_fun, apply_fun) 對,代表給定層序列的串列組合。

jax.example_libraries.stax.shape_dependent(make_layer)[source]#

組合器,用於延遲層建構子對,直到輸入形狀已知。

參數:

make_layer – 一個單參數函數,它接受輸入形狀作為參數(正整數的元組),並返回一個 (init_fun, apply_fun) 對。

返回:

一個新的層,表示一個 (init_fun, apply_fun) 對,表示與 make_layer 返回的層相同的層,但其建構被延遲到輸入形狀已知時。