diff --git a/adabound/adabound.py b/adabound/adabound.py index c1d58e2..c8327fa 100644 --- a/adabound/adabound.py +++ b/adabound/adabound.py @@ -88,11 +88,11 @@ def step(self, closure=None): state['step'] += 1 if group['weight_decay'] != 0: - grad = grad.add(group['weight_decay'], p.data) + grad = grad.add(p.data, alpha=group['weight_decay']) # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(1 - beta1, grad) - exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) if amsbound: # Maintains the maximum of all 2nd moment running avg. till now torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)