Skip to content

Commit

Permalink
add val loss plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
tmabraham committed Apr 10, 2023
1 parent f57bdf5 commit 39d7734
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 63 deletions.
35 changes: 2 additions & 33 deletions miniai/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,38 +122,6 @@
'miniai.diffusion.timestep_embedding': ( 'diffusion-attn-cond.html#timestep_embedding',
'miniai/diffusion.py'),
'miniai.diffusion.upsample': ('diffusion-attn-cond.html#upsample', 'miniai/diffusion.py')},
'miniai.diffusion2': { 'miniai.diffusion2.DownBlock': ('diffusion-attn-nodownsave.html#downblock', 'miniai/diffusion2.py'),
'miniai.diffusion2.DownBlock.__init__': ( 'diffusion-attn-nodownsave.html#downblock.__init__',
'miniai/diffusion2.py'),
'miniai.diffusion2.DownBlock.forward': ( 'diffusion-attn-nodownsave.html#downblock.forward',
'miniai/diffusion2.py'),
'miniai.diffusion2.EmbResBlock': ('diffusion-attn-nodownsave.html#embresblock', 'miniai/diffusion2.py'),
'miniai.diffusion2.EmbResBlock.__init__': ( 'diffusion-attn-nodownsave.html#embresblock.__init__',
'miniai/diffusion2.py'),
'miniai.diffusion2.EmbResBlock.forward': ( 'diffusion-attn-nodownsave.html#embresblock.forward',
'miniai/diffusion2.py'),
'miniai.diffusion2.EmbUNetModel': ( 'diffusion-attn-nodownsave.html#embunetmodel',
'miniai/diffusion2.py'),
'miniai.diffusion2.EmbUNetModel.__init__': ( 'diffusion-attn-nodownsave.html#embunetmodel.__init__',
'miniai/diffusion2.py'),
'miniai.diffusion2.EmbUNetModel.forward': ( 'diffusion-attn-nodownsave.html#embunetmodel.forward',
'miniai/diffusion2.py'),
'miniai.diffusion2.UpBlock': ('diffusion-attn-nodownsave.html#upblock', 'miniai/diffusion2.py'),
'miniai.diffusion2.UpBlock.__init__': ( 'diffusion-attn-nodownsave.html#upblock.__init__',
'miniai/diffusion2.py'),
'miniai.diffusion2.UpBlock.forward': ( 'diffusion-attn-nodownsave.html#upblock.forward',
'miniai/diffusion2.py'),
'miniai.diffusion2.abar': ('diffusion-attn-nodownsave.html#abar', 'miniai/diffusion2.py'),
'miniai.diffusion2.ddim_step': ('diffusion-attn-nodownsave.html#ddim_step', 'miniai/diffusion2.py'),
'miniai.diffusion2.inv_abar': ('diffusion-attn-nodownsave.html#inv_abar', 'miniai/diffusion2.py'),
'miniai.diffusion2.lin': ('diffusion-attn-nodownsave.html#lin', 'miniai/diffusion2.py'),
'miniai.diffusion2.noisify': ('diffusion-attn-nodownsave.html#noisify', 'miniai/diffusion2.py'),
'miniai.diffusion2.pre_conv': ('diffusion-attn-nodownsave.html#pre_conv', 'miniai/diffusion2.py'),
'miniai.diffusion2.sample': ('diffusion-attn-nodownsave.html#sample', 'miniai/diffusion2.py'),
'miniai.diffusion2.saved': ('diffusion-attn-nodownsave.html#saved', 'miniai/diffusion2.py'),
'miniai.diffusion2.timestep_embedding': ( 'diffusion-attn-nodownsave.html#timestep_embedding',
'miniai/diffusion2.py'),
'miniai.diffusion2.upsample': ('diffusion-attn-nodownsave.html#upsample', 'miniai/diffusion2.py')},
'miniai.fid': { 'miniai.fid.ImageEval': ('fid.html#imageeval', 'miniai/fid.py'),
'miniai.fid.ImageEval.__init__': ('fid.html#imageeval.__init__', 'miniai/fid.py'),
'miniai.fid.ImageEval.fid': ('fid.html#imageeval.fid', 'miniai/fid.py'),
Expand Down Expand Up @@ -218,6 +186,7 @@
'miniai.learner.ProgressCB.__init__': ('learner.html#progresscb.__init__', 'miniai/learner.py'),
'miniai.learner.ProgressCB._log': ('learner.html#progresscb._log', 'miniai/learner.py'),
'miniai.learner.ProgressCB.after_batch': ('learner.html#progresscb.after_batch', 'miniai/learner.py'),
'miniai.learner.ProgressCB.after_epoch': ('learner.html#progresscb.after_epoch', 'miniai/learner.py'),
'miniai.learner.ProgressCB.before_epoch': ('learner.html#progresscb.before_epoch', 'miniai/learner.py'),
'miniai.learner.ProgressCB.before_fit': ('learner.html#progresscb.before_fit', 'miniai/learner.py'),
'miniai.learner.SingleBatchCB': ('learner.html#singlebatchcb', 'miniai/learner.py'),
Expand Down Expand Up @@ -270,4 +239,4 @@
'miniai.training.accuracy': ('minibatch_training.html#accuracy', 'miniai/training.py'),
'miniai.training.fit': ('minibatch_training.html#fit', 'miniai/training.py'),
'miniai.training.get_dls': ('minibatch_training.html#get_dls', 'miniai/training.py'),
'miniai.training.report': ('minibatch_training.html#report', 'miniai/training.py')}}}
'miniai.training.report': ('minibatch_training.html#report', 'miniai/training.py')}}}
23 changes: 15 additions & 8 deletions miniai/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def before_fit(self, learn):
self.first = True
if hasattr(learn, 'metrics'): learn.metrics._log = self._log
self.losses = []
self.val_losses = []

def _log(self, d):
if self.first:
Expand All @@ -110,9 +111,15 @@ def after_batch(self, learn):
learn.dl.comment = f'{learn.loss:.3f}'
if self.plot and hasattr(learn, 'metrics') and learn.training:
self.losses.append(learn.loss.item())
self.mbar.update_graph([[fc.L.range(self.losses), self.losses]])
if self.val_losses: self.mbar.update_graph([[fc.L.range(self.losses), self.losses],[fc.L.range(learn.epoch).map(lambda x: (x+1)*len(learn.dls.train)), self.val_losses]])

def after_epoch(self, learn):
    "After a validation epoch, record the epoch's validation loss and redraw the loss graph."
    # Only act at the end of the validation phase, and only when plotting is on
    # and the learner actually tracks metrics.
    if learn.training: return
    if not (self.plot and hasattr(learn, 'metrics')): return
    self.val_losses.append(learn.metrics.all_metrics['loss'].compute())
    # x-coords: training losses are per-batch; validation losses sit at each epoch boundary.
    train_xs = fc.L.range(self.losses)
    val_xs = fc.L.range(learn.epoch+1).map(lambda e: (e+1)*len(learn.dls.train))
    self.mbar.update_graph([[train_xs, self.losses], [val_xs, self.val_losses]])

# %% ../nbs/09_learner.ipynb 47
# %% ../nbs/09_learner.ipynb 48
class with_cbs:
def __init__(self, nm): self.nm = nm
def __call__(self, f):
Expand All @@ -125,7 +132,7 @@ def _f(o, *args, **kwargs):
finally: o.callback(f'cleanup_{self.nm}')
return _f

# %% ../nbs/09_learner.ipynb 48
# %% ../nbs/09_learner.ipynb 49
class Learner():
def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD):
cbs = fc.L(cbs)
Expand Down Expand Up @@ -181,15 +188,15 @@ def callback(self, method_nm): run_cbs(self.cbs, method_nm, self)
@property
def training(self): return self.model.training

# %% ../nbs/09_learner.ipynb 51
# %% ../nbs/09_learner.ipynb 52
class TrainLearner(Learner):
    "Learner subclass supplying the standard implementations of the five training-step callbacks."
    def predict(self):
        # Forward pass on the inputs of the current batch.
        self.preds = self.model(self.batch[0])

    def get_loss(self):
        # Compare predictions against the batch targets.
        self.loss = self.loss_func(self.preds, self.batch[1])

    def backward(self):
        self.loss.backward()

    def step(self):
        self.opt.step()

    def zero_grad(self):
        self.opt.zero_grad()

# %% ../nbs/09_learner.ipynb 52
# %% ../nbs/09_learner.ipynb 53
class MomentumLearner(TrainLearner):
def __init__(self, model, dls, loss_func, lr=None, cbs=None, opt_func=optim.SGD, mom=0.85):
self.mom = mom
Expand All @@ -199,10 +206,10 @@ def zero_grad(self):
with torch.no_grad():
for p in self.model.parameters(): p.grad *= self.mom

# %% ../nbs/09_learner.ipynb 57
# %% ../nbs/09_learner.ipynb 58
from torch.optim.lr_scheduler import ExponentialLR

# %% ../nbs/09_learner.ipynb 59
# %% ../nbs/09_learner.ipynb 60
class LRFinderCB(Callback):
def __init__(self, gamma=1.3, max_mult=3): fc.store_attr()

Expand All @@ -225,7 +232,7 @@ def cleanup_fit(self, learn):
plt.plot(self.lrs, self.losses)
plt.xscale('log')

# %% ../nbs/09_learner.ipynb 61
# %% ../nbs/09_learner.ipynb 62
@fc.patch
def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10):
    "Run an LR-finder sweep: fit from `start_lr`, multiplying the LR by `gamma` each batch until the loss exceeds `max_mult` times the best seen."
    finder = LRFinderCB(gamma=gamma, max_mult=max_mult)
    self.fit(max_epochs, lr=start_lr, cbs=finder)
Loading

0 comments on commit 39d7734

Please sign in to comment.