Skip to content

Commit

Permalink
superres
Browse files Browse the repository at this point in the history
  • Loading branch information
jph00 committed Jan 26, 2023
1 parent 0a3c31e commit 9d8748b
Show file tree
Hide file tree
Showing 7 changed files with 1,515 additions and 97 deletions.
12 changes: 8 additions & 4 deletions miniai/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,22 @@ def show_image_batch(self:Learner, max_n=9, cbs=None, **kwargs):

# %% ../nbs/14_augment.ipynb 38
class CapturePreds(Callback):
def before_fit(self, learn): self.all_preds,self.all_targs = [],[]
def before_fit(self, learn): self.all_inps,self.all_preds,self.all_targs = [],[],[]
def after_batch(self, learn):
self.all_inps. append(to_cpu(learn.batch[0]))
self.all_preds.append(to_cpu(learn.preds))
self.all_targs.append(to_cpu(learn.batch[1]))
def after_fit(self, learn): self.all_preds,self.all_targs = torch.cat(self.all_preds),torch.cat(self.all_targs)
def after_fit(self, learn):
self.all_preds,self.all_targs,self.all_inps = map(torch.cat, [self.all_preds,self.all_targs,self.all_inps])

# %% ../nbs/14_augment.ipynb 39
@fc.patch
def capture_preds(self: Learner, cbs=None):
def capture_preds(self: Learner, cbs=None, inps=False):
cp = CapturePreds()
self.fit(1, train=False, cbs=[cp]+fc.L(cbs))
return cp.all_preds,cp.all_targs
res = cp.all_preds,cp.all_targs
if inps: res = res+(cp.all_inps,)
return res

# %% ../nbs/14_augment.ipynb 54
def _rand_erase1(x, pct, xm, xs, mn, mx):
Expand Down
1 change: 1 addition & 0 deletions miniai/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
__all__ = ['def_device', 'conv', 'to_device', 'collate_device']

# %% ../nbs/07_convolutions.ipynb 2
import torch
from torch import nn

from torch.utils.data import default_collate
Expand Down
1 change: 1 addition & 0 deletions nbs/07_convolutions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"outputs": [],
"source": [
"#|export\n",
"import torch\n",
"from torch import nn\n",
"\n",
"from torch.utils.data import default_collate\n",
Expand Down
12 changes: 8 additions & 4 deletions nbs/14_augment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1418,11 +1418,13 @@
"source": [
"#| export\n",
"class CapturePreds(Callback):\n",
" def before_fit(self, learn): self.all_preds,self.all_targs = [],[]\n",
" def before_fit(self, learn): self.all_inps,self.all_preds,self.all_targs = [],[],[]\n",
" def after_batch(self, learn):\n",
" self.all_inps. append(to_cpu(learn.batch[0]))\n",
" self.all_preds.append(to_cpu(learn.preds))\n",
" self.all_targs.append(to_cpu(learn.batch[1]))\n",
" def after_fit(self, learn): self.all_preds,self.all_targs = torch.cat(self.all_preds),torch.cat(self.all_targs)"
" def after_fit(self, learn):\n",
" self.all_preds,self.all_targs,self.all_inps = map(torch.cat, [self.all_preds,self.all_targs,self.all_inps])"
]
},
{
Expand All @@ -1433,10 +1435,12 @@
"source": [
"#| export\n",
"@fc.patch\n",
"def capture_preds(self: Learner, cbs=None):\n",
"def capture_preds(self: Learner, cbs=None, inps=False):\n",
" cp = CapturePreds()\n",
" self.fit(1, train=False, cbs=[cp]+fc.L(cbs))\n",
" return cp.all_preds,cp.all_targs"
" res = cp.all_preds,cp.all_targs\n",
" if inps: res = res+(cp.all_inps,)\n",
" return res"
]
},
{
Expand Down
12 changes: 4 additions & 8 deletions nbs/22_cosine.ipynb

Large diffs are not rendered by default.

155 changes: 74 additions & 81 deletions nbs/23_karras.ipynb

Large diffs are not rendered by default.

1,419 changes: 1,419 additions & 0 deletions nbs/25_superres.ipynb

Large diffs are not rendered by default.

0 comments on commit 9d8748b

Please sign in to comment.