jax.example_libraries.stax
模組#
Stax 是一個小巧但靈活的神經網路規格庫,從頭開始建構。
您可能不打算匯入此模組! Stax 僅作為範例庫。JAX 有許多其他功能更完整的神經網路庫,包括 Google 的 Flax 和 DeepMind 的 Haiku。
- jax.example_libraries.stax.AvgPool(window_shape, strides=None, padding='VALID', spec=None)[source]#
用於池化層的層建構函式。
- jax.example_libraries.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)[source]#
用於批次正規化層的層建構函式。
- jax.example_libraries.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#
用於一般卷積層的層建構函式。
- jax.example_libraries.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#
用於一般轉置卷積層的層建構函式。
- jax.example_libraries.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#
用於一般轉置卷積層的層建構函式。
- jax.example_libraries.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)[source]#
用於密集(完全連接)層的層建構子函式。
- jax.example_libraries.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[source]#
用於一般卷積層的層建構函式。
- jax.example_libraries.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[source]#
用於一般轉置卷積層的層建構函式。
- jax.example_libraries.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)[source]#
用於池化層的層建構函式。
- jax.example_libraries.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)[source]#
用於池化層的層建構函式。
- jax.example_libraries.stax.parallel(*layers)[source]#
用於平行組合層的組合器。
由此組合器產生的層通常與 FanOut 和 FanInSum 層一起使用。
- 參數:
*layers – 一系列層,每個層都是一個 (init_fun, apply_fun) 對。
- 返回:
一個新的層,表示一個 (init_fun, apply_fun) 對,代表給定層序列的平行組合。特別是,返回的層接受一系列輸入,並返回與參數 layers 長度相同的輸出序列。