Skip to content

Commit

Permalink
fixed issue that warmup baselines alpha could become larger than 1
Browse files Browse the repository at this point in the history
  • Loading branch information
LTluttmann committed May 31, 2023
1 parent 6dbad47 commit e282eff
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions reinforce_baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ def eval(self, x, c):
v, l = self.baseline.eval(x, c)
vw, lw = self.warmup_baseline.eval(x, c)
# Return convex combination of baseline and of loss
return self.alpha * v + (1 - self.alpha) * vw, self.alpha * l + (1 - self.alpha * lw)
return self.alpha * v + (1 - self.alpha) * vw, self.alpha * l + (1 - self.alpha) * lw

def epoch_callback(self, model, epoch):
# Need to call epoch callback of inner model (also after first epoch if we have not used it)
self.baseline.epoch_callback(model, epoch)
self.alpha = (epoch + 1) / float(self.n_epochs)
if epoch < self.n_epochs:
self.alpha = (epoch + 1) / float(self.n_epochs)
print("Set warmup alpha = {}".format(self.alpha))

def state_dict(self):
Expand Down

0 comments on commit e282eff

Please sign in to comment.