jax.example_libraries.optimizers
模組#
如何使用 JAX 撰寫最佳化器的範例。
您可能不想要匯入這個模組!這個函式庫中的最佳化器僅作為範例。如果您正在尋找功能完整最佳化器函式庫,請考慮 Optax。
這個模組包含一些方便的最佳化器定義,特別是初始化和更新函式,這些函式可以與 ndarray 或任意巢狀 tuple/list/dict 的 ndarray 一起使用。
最佳化器建模為 (init_fun, update_fun, get_params)
函式三元組,其中元件函式具有以下簽名
init_fun(params)
Args:
params: pytree representing the initial parameters.
Returns:
A pytree representing the initial optimizer state, which includes the
initial parameters and may also include auxiliary values like initial
momentum. The optimizer state pytree structure generally differs from that
of `params`.
update_fun(step, grads, opt_state)
Args:
step: integer representing the step index.
grads: a pytree with the same structure as `get_params(opt_state)`
representing the gradients to be used in updating the optimizer state.
opt_state: a pytree representing the optimizer state to be updated.
Returns:
A pytree with the same structure as the `opt_state` argument representing
the updated optimizer state.
get_params(opt_state)
Args:
opt_state: pytree representing an optimizer state.
Returns:
A pytree representing the parameters extracted from `opt_state`, such that
the invariant `params == get_params(init_fun(params))` holds true.
請注意,最佳化器實作在 opt_state 的形式上具有很大的彈性:它只需要是 JaxTypes 的 pytree(以便它可以傳遞到 api.py 中定義的 JAX 轉換),並且它必須可被 update_fun 和 get_params 使用。
使用範例
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)
def step(step, opt_state):
value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
opt_state = opt_update(step, grads, opt_state)
return value, opt_state
for i in range(num_steps):
value, opt_state = step(i, opt_state)
- class jax.example_libraries.optimizers.JoinPoint(subtree)[source]#
基底:
object
標記兩個已加入(巢狀)pytrees 之間的邊界。
- class jax.example_libraries.optimizers.Optimizer(init_fn, update_fn, params_fn)[source]#
基底:
NamedTuple
- 參數:
init_fn (InitFn)
update_fn (UpdateFn)
params_fn (ParamsFn)
- init_fn: InitFn#
欄位編號 0 的別名
- params_fn: ParamsFn#
欄位編號 2 的別名
- update_fn: UpdateFn#
欄位編號 1 的別名
- class jax.example_libraries.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)#
基底:
tuple
- packed_state#
欄位編號 0 的別名
- subtree_defs#
欄位編號 2 的別名
- tree_def#
欄位編號 1 的別名
- jax.example_libraries.optimizers.adagrad(step_size, momentum=0.9)[source]#
建構 Adagrad 的最佳化器三元組。
線上學習和隨機最佳化的自適應次梯度方法:http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
- 參數:
step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。
momentum – 選用,動量的正純量值
- 傳回:
(init_fun, update_fun, get_params) 三元組。
- jax.example_libraries.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]#
建構 Adam 的最佳化器三元組。
- 參數:
step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。
b1 – 選用,beta_1 的正純量值,第一個動量估計值的指數衰減率(預設值為 0.9)。
b2 – 選用,beta_2 的正純量值,第二個動量估計值的指數衰減率(預設值為 0.999)。
eps – 選用,epsilon 的正純量值,用於數值穩定性的小常數(預設值為 1e-8)。
- 傳回:
(init_fun, update_fun, get_params) 三元組。
- jax.example_libraries.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]#
建構 AdaMax 的最佳化器三元組(Adam 的變體,基於無窮範數)。
- 參數:
step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。
b1 – 選用,beta_1 的正純量值,第一個動量估計值的指數衰減率(預設值為 0.9)。
b2 – 選用,beta_2 的正純量值,第二個動量估計值的指數衰減率(預設值為 0.999)。
eps – 選用,epsilon 的正純量值,用於數值穩定性的小常數(預設值為 1e-8)。
- 傳回:
(init_fun, update_fun, get_params) 三元組。
- jax.example_libraries.optimizers.clip_grads(grad_tree, max_norm)[source]#
將儲存為陣列 pytree 的梯度裁剪到最大範數 max_norm。
- jax.example_libraries.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)[source]#
- jax.example_libraries.optimizers.make_schedule(scalar_or_schedule)[source]#
- 參數:
scalar_or_schedule (float | Schedule)
- 傳回類型:
Schedule
- jax.example_libraries.optimizers.momentum(step_size, mass)[source]#
建構具有動量的 SGD 的最佳化器三元組。
- 參數:
step_size (Schedule) – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。
mass (float) – 表示動量係數的正純量。
- 傳回:
(init_fun, update_fun, get_params) 三元組。
- jax.example_libraries.optimizers.nesterov(step_size, mass)[source]#
建構具有 Nesterov 動量的 SGD 的最佳化器三元組。
- 參數:
step_size (Schedule) – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。
mass (float) – 表示動量係數的正純量。
- 傳回:
(init_fun, update_fun, get_params) 三元組。
- jax.example_libraries.optimizers.optimizer(opt_maker)[source]#
裝飾器,使針對陣列定義的最佳化器通用化到容器。
使用此裝飾器,您可以撰寫僅對單個陣列進行操作的 init、update 和 get_params 函式,並將它們轉換為對參數 pytree 進行操作的對應函式。 有關範例,請參閱 optimizers.py 中定義的最佳化器。
- 參數:
opt_maker (Callable[..., tuple[Callable[[Params], State], Callable[[Step, Updates, Params], Params], Callable[[State], Params]]]) –
一個函式,傳回
(init_fun, update_fun, get_params)
函式三元組,這些函式可能僅適用於 ndarray,如同init_fun :: ndarray -> OptStatePytree ndarray update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray get_params :: OptStatePytree ndarray -> ndarray
- 傳回:
一個
(init_fun, update_fun, get_params)
函式三元組,適用於任意 pytree,如同init_fun :: ParameterPytree ndarray -> OptimizerState update_fun :: OptimizerState -> OptimizerState get_params :: OptimizerState -> ParameterPytree ndarray
傳回的函式使用的 OptimizerState pytree 類型與
ParameterPytree (OptStatePytree ndarray)
同構,但可能會將狀態儲存為例如部分展平的資料結構以提高效能。- 傳回類型:
Callable[…, Optimizer]
- jax.example_libraries.optimizers.pack_optimizer_state(marked_pytree)[source]#
將標記的 pytree 轉換為 OptimizerState。
unpack_optimizer_state 的反向操作。 將外 pytree 的葉子表示為 JoinPoint 的標記 pytree 轉換回 OptimizerState。 此函式旨在於還原序列化最佳化器狀態時很有用。
- 參數:
marked_pytree – 包含 JoinPoint 葉子的 pytree,這些葉子包含更多 pytree。
- 傳回:
與輸入引數等效的 OptimizerState。
- jax.example_libraries.optimizers.piecewise_constant(boundaries, values)[source]#
- 參數:
boundaries (Any)
values (Any)
- jax.example_libraries.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)[source]#
- jax.example_libraries.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)[source]#
建構 RMSProp 的最佳化器三元組。
- 參數:
step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。 gamma:衰減參數。 eps:epsilon 參數。
- 傳回:
(init_fun, update_fun, get_params) 三元組。
- jax.example_libraries.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)[source]#
建構具有動量的 RMSProp 的最佳化器三元組。
此最佳化器與 rmsprop 最佳化器分開,因為它需要追蹤其他參數。
- 參數:
step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。
gamma – 衰減參數。
eps – epsilon 參數。
momentum – 動量參數。
- 傳回:
(init_fun, update_fun, get_params) 三元組。
- jax.example_libraries.optimizers.sgd(step_size)[source]#
建構隨機梯度下降的最佳化器三元組。
- 參數:
step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。
- 傳回:
(init_fun, update_fun, get_params) 三元組。
- jax.example_libraries.optimizers.sm3(step_size, momentum=0.9)[source]#
建構 SM3 的最佳化器三元組。
大規模學習的記憶體效率自適應最佳化。https://arxiv.org/abs/1901.11150
- 參數:
step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。
momentum – 選用,動量的正純量值
- 傳回:
(init_fun, update_fun, get_params) 三元組。