jax.example_libraries.optimizers 模組#

如何使用 JAX 撰寫最佳化器的範例。

您可能不想要匯入這個模組!這個函式庫中的最佳化器僅作為範例。如果您正在尋找功能完整最佳化器函式庫,請考慮 Optax

這個模組包含一些方便的最佳化器定義,特別是初始化和更新函式,這些函式可以與 ndarray 或任意巢狀 tuple/list/dict 的 ndarray 一起使用。

最佳化器建模為 (init_fun, update_fun, get_params) 函式三元組,其中元件函式具有以下簽名

init_fun(params)

Args:
  params: pytree representing the initial parameters.

Returns:
  A pytree representing the initial optimizer state, which includes the
  initial parameters and may also include auxiliary values like initial
  momentum. The optimizer state pytree structure generally differs from that
  of `params`.
update_fun(step, grads, opt_state)

Args:
  step: integer representing the step index.
  grads: a pytree with the same structure as `get_params(opt_state)`
    representing the gradients to be used in updating the optimizer state.
  opt_state: a pytree representing the optimizer state to be updated.

Returns:
  A pytree with the same structure as the `opt_state` argument representing
  the updated optimizer state.
get_params(opt_state)

Args:
  opt_state: pytree representing an optimizer state.

Returns:
  A pytree representing the parameters extracted from `opt_state`, such that
  the invariant `params == get_params(init_fun(params))` holds true.

請注意,最佳化器實作在 opt_state 的形式上具有很大的彈性:它只需要是 JaxTypes 的 pytree(以便它可以傳遞到 api.py 中定義的 JAX 轉換),並且它必須可被 update_fun 和 get_params 使用。

使用範例

opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)

def step(step, opt_state):
  value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
  opt_state = opt_update(step, grads, opt_state)
  return value, opt_state

for i in range(num_steps):
  value, opt_state = step(i, opt_state)
class jax.example_libraries.optimizers.JoinPoint(subtree)[source]#

基底:object

標記兩個已加入(巢狀)pytrees 之間的邊界。

class jax.example_libraries.optimizers.Optimizer(init_fn, update_fn, params_fn)[source]#

基底:NamedTuple

參數:
  • init_fn (InitFn)

  • update_fn (UpdateFn)

  • params_fn (ParamsFn)

init_fn: InitFn#

欄位編號 0 的別名

params_fn: ParamsFn#

欄位編號 2 的別名

update_fn: UpdateFn#

欄位編號 1 的別名

class jax.example_libraries.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)#

基底:tuple

packed_state#

欄位編號 0 的別名

subtree_defs#

欄位編號 2 的別名

tree_def#

欄位編號 1 的別名

jax.example_libraries.optimizers.adagrad(step_size, momentum=0.9)[source]#

建構 Adagrad 的最佳化器三元組。

線上學習和隨機最佳化的自適應次梯度方法:http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf

參數:
  • step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。

  • momentum – 選用,動量的正純量值

傳回:

(init_fun, update_fun, get_params) 三元組。

jax.example_libraries.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]#

建構 Adam 的最佳化器三元組。

參數:
  • step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。

  • b1 – 選用,beta_1 的正純量值,第一個動量估計值的指數衰減率(預設值為 0.9)。

  • b2 – 選用,beta_2 的正純量值,第二個動量估計值的指數衰減率(預設值為 0.999)。

  • eps – 選用,epsilon 的正純量值,用於數值穩定性的小常數(預設值為 1e-8)。

傳回:

(init_fun, update_fun, get_params) 三元組。

jax.example_libraries.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)[source]#

建構 AdaMax 的最佳化器三元組(Adam 的變體,基於無窮範數)。

參數:
  • step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。

  • b1 – 選用,beta_1 的正純量值,第一個動量估計值的指數衰減率(預設值為 0.9)。

  • b2 – 選用,beta_2 的正純量值,第二個動量估計值的指數衰減率(預設值為 0.999)。

  • eps – 選用,epsilon 的正純量值,用於數值穩定性的小常數(預設值為 1e-8)。

傳回:

(init_fun, update_fun, get_params) 三元組。

jax.example_libraries.optimizers.clip_grads(grad_tree, max_norm)[source]#

將儲存為陣列 pytree 的梯度裁剪到最大範數 max_norm

jax.example_libraries.optimizers.constant(step_size)[source]#
傳回類型:

Schedule

jax.example_libraries.optimizers.exponential_decay(step_size, decay_steps, decay_rate)[source]#
jax.example_libraries.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)[source]#
jax.example_libraries.optimizers.l2_norm(tree)[source]#

計算陣列 pytree 的 l2 範數。對於權重衰減很有用。

jax.example_libraries.optimizers.make_schedule(scalar_or_schedule)[source]#
參數:

scalar_or_schedule (float | Schedule)

傳回類型:

Schedule

jax.example_libraries.optimizers.momentum(step_size, mass)[source]#

建構具有動量的 SGD 的最佳化器三元組。

參數:
  • step_size (Schedule) – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。

  • mass (float) – 表示動量係數的正純量。

傳回:

(init_fun, update_fun, get_params) 三元組。

jax.example_libraries.optimizers.nesterov(step_size, mass)[source]#

建構具有 Nesterov 動量的 SGD 的最佳化器三元組。

參數:
  • step_size (Schedule) – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。

  • mass (float) – 表示動量係數的正純量。

傳回:

(init_fun, update_fun, get_params) 三元組。

jax.example_libraries.optimizers.optimizer(opt_maker)[source]#

裝飾器,使針對陣列定義的最佳化器通用化到容器。

使用此裝飾器,您可以撰寫僅對單個陣列進行操作的 init、update 和 get_params 函式,並將它們轉換為對參數 pytree 進行操作的對應函式。 有關範例,請參閱 optimizers.py 中定義的最佳化器。

參數:

opt_maker (Callable[..., tuple[Callable[[Params], State], Callable[[Step, Updates, Params], Params], Callable[[State], Params]]]) –

一個函式,傳回 (init_fun, update_fun, get_params) 函式三元組,這些函式可能僅適用於 ndarray,如同

init_fun :: ndarray -> OptStatePytree ndarray
update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray
get_params :: OptStatePytree ndarray -> ndarray

傳回:

一個 (init_fun, update_fun, get_params) 函式三元組,適用於任意 pytree,如同

init_fun :: ParameterPytree ndarray -> OptimizerState
update_fun :: OptimizerState -> OptimizerState
get_params :: OptimizerState -> ParameterPytree ndarray

傳回的函式使用的 OptimizerState pytree 類型與 ParameterPytree (OptStatePytree ndarray) 同構,但可能會將狀態儲存為例如部分展平的資料結構以提高效能。

傳回類型:

Callable[…, Optimizer]

jax.example_libraries.optimizers.pack_optimizer_state(marked_pytree)[source]#

將標記的 pytree 轉換為 OptimizerState。

unpack_optimizer_state 的反向操作。 將外 pytree 的葉子表示為 JoinPoint 的標記 pytree 轉換回 OptimizerState。 此函式旨在於還原序列化最佳化器狀態時很有用。

參數:

marked_pytree – 包含 JoinPoint 葉子的 pytree,這些葉子包含更多 pytree。

傳回:

與輸入引數等效的 OptimizerState。

jax.example_libraries.optimizers.piecewise_constant(boundaries, values)[source]#
參數:
  • boundaries (Any)

  • values (Any)

jax.example_libraries.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)[source]#
jax.example_libraries.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)[source]#

建構 RMSProp 的最佳化器三元組。

參數:

step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。 gamma:衰減參數。 eps:epsilon 參數。

傳回:

(init_fun, update_fun, get_params) 三元組。

jax.example_libraries.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)[source]#

建構具有動量的 RMSProp 的最佳化器三元組。

此最佳化器與 rmsprop 最佳化器分開,因為它需要追蹤其他參數。

參數:
  • step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。

  • gamma – 衰減參數。

  • eps – epsilon 參數。

  • momentum – 動量參數。

傳回:

(init_fun, update_fun, get_params) 三元組。

jax.example_libraries.optimizers.sgd(step_size)[source]#

建構隨機梯度下降的最佳化器三元組。

參數:

step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。

傳回:

(init_fun, update_fun, get_params) 三元組。

jax.example_libraries.optimizers.sm3(step_size, momentum=0.9)[source]#

建構 SM3 的最佳化器三元組。

大規模學習的記憶體效率自適應最佳化。https://arxiv.org/abs/1901.11150

參數:
  • step_size – 正純量,或表示將迭代索引對應到正純量的步長排程的可呼叫物件。

  • momentum – 選用,動量的正純量值

傳回:

(init_fun, update_fun, get_params) 三元組。

jax.example_libraries.optimizers.unpack_optimizer_state(opt_state)[source]#

將 OptimizerState 轉換為標記的 pytree。

將 OptimizerState 轉換為標記的 pytree,其中外 pytree 的葉子表示為 JoinPoint,以避免遺失資訊。 此函式旨在於序列化最佳化器狀態時很有用。

參數:

opt_state – OptimizerState

傳回:

具有 JoinPoint 葉子的 pytree,其中包含第二層 pytree。