Coverage for src/deepdraw/engine/adabound.py: 16%
132 statements

# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Implementation of the AdaBound optimizer.

<https://github.com/Luolc/AdaBound/blob/master/adabound/adabound.py>::

    @inproceedings{Luo2019AdaBound,
      author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu},
      title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate},
      booktitle = {Proceedings of the 7th International Conference on Learning Representations},
      month = {May},
      year = {2019},
      address = {New Orleans, Louisiana}
    }
"""

import math

import torch
import torch.optim


class AdaBound(torch.optim.Optimizer):
    """Implements the AdaBound algorithm.

    Parameters
    ----------

    params : list
        Iterable of parameters to optimize or dicts defining parameter groups

    lr : :obj:`float`, optional
        Adam learning rate

    betas : :obj:`tuple`, optional
        Coefficients (as a 2-tuple of floats) used for computing running
        averages of the gradient and its square

    final_lr : :obj:`float`, optional
        Final (SGD) learning rate

    gamma : :obj:`float`, optional
        Convergence speed of the bound functions

    eps : :obj:`float`, optional
        Term added to the denominator to improve numerical stability

    weight_decay : :obj:`float`, optional
        Weight decay (L2 penalty)

    amsbound : :obj:`bool`, optional
        Whether to use the AMSBound variant of this algorithm
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        final_lr=0.1,
        gamma=1e-3,
        eps=1e-8,
        weight_decay=0,
        amsbound=False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= final_lr:
            raise ValueError(f"Invalid final learning rate: {final_lr}")
        if not 0.0 <= gamma < 1.0:
            raise ValueError(f"Invalid gamma parameter: {gamma}")
        defaults = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super().__init__(params, defaults)

        self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsbound", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Parameters
        ----------

        closure : :obj:`callable`, optional
            A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "AdaBound does not support sparse gradients, please consider SparseAdam instead"
                    )
                amsbound = group["amsbound"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsbound:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                if group["weight_decay"] != 0:
                    grad = grad.add(p.data, alpha=group["weight_decay"])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsbound:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = (
                    group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                )

                # Applies bounds on the actual learning rate.
                # lr_scheduler cannot affect final_lr; this is a workaround to
                # apply lr decay.
                final_lr = group["final_lr"] * group["lr"] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group["gamma"] * state["step"] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group["gamma"] * state["step"])
                )
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )

                p.data.add_(-step_size)

        return loss
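
# A minimal usage sketch (illustration only, not part of the original module);
# it fits a single linear layer on random data with AdaBound:
#
#     import torch.nn as nn
#
#     model = nn.Linear(10, 1)
#     optimizer = AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)
#     inputs, targets = torch.randn(32, 10), torch.randn(32, 1)
#     for _ in range(100):
#         optimizer.zero_grad()
#         loss = nn.functional.mse_loss(model(inputs), targets)
#         loss.backward()
#         optimizer.step()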


class AdaBoundW(torch.optim.Optimizer):
    """Implements the AdaBound algorithm with decoupled weight decay (see
    https://arxiv.org/abs/1711.05101).

    Parameters
    ----------

    params : list
        Iterable of parameters to optimize or dicts defining parameter groups

    lr : :obj:`float`, optional
        Adam learning rate

    betas : :obj:`tuple`, optional
        Coefficients (as a 2-tuple of floats) used for computing running
        averages of the gradient and its square

    final_lr : :obj:`float`, optional
        Final (SGD) learning rate

    gamma : :obj:`float`, optional
        Convergence speed of the bound functions

    eps : :obj:`float`, optional
        Term added to the denominator to improve numerical stability

    weight_decay : :obj:`float`, optional
        Weight decay (L2 penalty)

    amsbound : :obj:`bool`, optional
        Whether to use the AMSBound variant of this algorithm
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        final_lr=0.1,
        gamma=1e-3,
        eps=1e-8,
        weight_decay=0,
        amsbound=False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= final_lr:
            raise ValueError(f"Invalid final learning rate: {final_lr}")
        if not 0.0 <= gamma < 1.0:
            raise ValueError(f"Invalid gamma parameter: {gamma}")
        defaults = dict(
            lr=lr,
            betas=betas,
            final_lr=final_lr,
            gamma=gamma,
            eps=eps,
            weight_decay=weight_decay,
            amsbound=amsbound,
        )
        super().__init__(params, defaults)

        self.base_lrs = list(map(lambda group: group["lr"], self.param_groups))

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsbound", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Parameters
        ----------

        closure : :obj:`callable`, optional
            A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "AdaBoundW does not support sparse gradients, please consider SparseAdam instead"
                    )
                amsbound = group["amsbound"]

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                    if amsbound:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsbound:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsbound:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = (
                    group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                )

                # Applies bounds on the actual learning rate.
                # lr_scheduler cannot affect final_lr; this is a workaround to
                # apply lr decay.
                final_lr = group["final_lr"] * group["lr"] / base_lr
                lower_bound = final_lr * (
                    1 - 1 / (group["gamma"] * state["step"] + 1)
                )
                upper_bound = final_lr * (
                    1 + 1 / (group["gamma"] * state["step"])
                )
                step_size = torch.full_like(denom, step_size)
                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(
                    exp_avg
                )

                if group["weight_decay"] != 0:
                    # Decoupled weight decay: the penalty is applied directly to
                    # the weights rather than folded into the gradient
                    decayed_weights = torch.mul(p.data, group["weight_decay"])
                    p.data.add_(-step_size)
                    p.data.sub_(decayed_weights)
                else:
                    p.data.add_(-step_size)

        return loss
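
# A minimal usage sketch for the decoupled-weight-decay variant (illustration
# only, not part of the original module); it also shows the optional closure
# argument to ``step()``:
#
#     import torch.nn as nn
#
#     model = nn.Linear(10, 1)
#     optimizer = AdaBoundW(
#         model.parameters(), lr=1e-3, final_lr=0.1, weight_decay=1e-4
#     )
#     inputs, targets = torch.randn(32, 10), torch.randn(32, 1)
#
#     def closure():
#         optimizer.zero_grad()
#         loss = nn.functional.mse_loss(model(inputs), targets)
#         loss.backward()
#         return loss
#
#     for _ in range(100):
#         optimizer.step(closure)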