# activations.py (forked from fastai/course22p2)
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_activations.ipynb.
# %% ../nbs/10_activations.ipynb 2
from __future__ import annotations
import random,math,torch,numpy as np,matplotlib.pyplot as plt
import fastcore.all as fc
from functools import partial
from .datasets import *
from .learner import *
# %% auto 0
__all__ = ['set_seed', 'Hook', 'Hooks', 'HooksCallback', 'append_stats', 'get_hist', 'get_min', 'ActivationStats']
# %% ../nbs/10_activations.ipynb 4
def set_seed(seed, deterministic=False):
    "Seed `torch`, `random` and numpy for reproducible runs; optionally force deterministic algorithms."
    torch.use_deterministic_algorithms(deterministic)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
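# Usage sketch (the seed value here is arbitrary): call once before building the
# model and dataloaders, e.g. `set_seed(42, deterministic=True)`, so torch,
# random and numpy all produce repeatable results.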
# %% ../nbs/10_activations.ipynb 30
class Hook():
    "Register a forward hook `f` on module `m`; `f` receives this `Hook` as its first argument."
    def __init__(self, m, f): self.hook = m.register_forward_hook(partial(f, self))
    def remove(self): self.hook.remove()
    def __del__(self): self.remove()
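# Usage sketch, assuming a hypothetical `model` and input batch `xb`, with
# `append_stats` defined below:
#   hook = Hook(model[0], append_stats)   # watch the first layer
#   model(xb)                             # forward pass populates hook.stats
#   hook.remove()                         # detach when done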
# %% ../nbs/10_activations.ipynb 42
class Hooks(list):
    "A `list` of `Hook`s usable as a context manager; all hooks are removed on exit."
    def __init__(self, ms, f): super().__init__([Hook(m, f) for m in ms])
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()
    def __del__(self): self.remove()
    def __delitem__(self, i):
        self[i].remove()
        super().__delitem__(i)
    def remove(self):
        for h in self: h.remove()
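# Usage sketch (hypothetical `model` and `xb` again; assumes `model` iterates
# over its layers, e.g. an `nn.Sequential`):
#   with Hooks(list(model), append_stats) as hooks:
#       model(xb)
#       means = hooks[0].stats[0]   # per-batch means from the first layer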
# %% ../nbs/10_activations.ipynb 46
class HooksCallback(Callback):
    "Run `hookfunc` via `Hooks` on the modules matching `mod_filter` (or the explicit `mods`), during training and/or validation."
    def __init__(self, hookfunc, mod_filter=fc.noop, on_train=True, on_valid=False, mods=None):
        fc.store_attr()
        super().__init__()
    def before_fit(self, learn):
        if self.mods: mods = self.mods
        else: mods = fc.filter_ex(learn.model.modules(), self.mod_filter)
        self.hooks = Hooks(mods, partial(self._hookfunc, learn))
    def _hookfunc(self, learn, *args, **kwargs):
        if (self.on_train and learn.training) or (self.on_valid and not learn.training): self.hookfunc(*args, **kwargs)
    def after_fit(self, learn): self.hooks.remove()
    def __iter__(self): return iter(self.hooks)
    def __len__(self): return len(self.hooks)
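# Usage sketch (assumes the Learner API from .learner; `fc.risinstance` is
# fastcore's isinstance-based filter, and the Learner arguments elided with
# `...` depend on your setup):
#   hc = HooksCallback(append_stats, mod_filter=fc.risinstance(torch.nn.Conv2d))
#   learn = Learner(model, dls, ..., cbs=[hc])
#   learn.fit(1)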
# %% ../nbs/10_activations.ipynb 51
def append_stats(hook, mod, inp, outp):
    "Hook function: record the mean, std and a histogram of the activations for each batch."
    if not hasattr(hook,'stats'): hook.stats = ([],[],[])
    acts = to_cpu(outp)
    hook.stats[0].append(acts.mean())
    hook.stats[1].append(acts.std())
    hook.stats[2].append(acts.abs().histc(40,0,10))  # 40 bins covering |act| in [0,10]
# %% ../nbs/10_activations.ipynb 53
# Thanks to @ste for the initial version of the histogram plotting code
def get_hist(h): return torch.stack(h.stats[2]).t().float().log1p()
# %% ../nbs/10_activations.ipynb 55
def get_min(h):
    "Fraction of activations falling in the lowest histogram bin, per recorded batch."
    h1 = torch.stack(h.stats[2]).t().float()
    return h1[0]/h1.sum(0)
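# Interpretation sketch: with 40 bins over [0,10], the first bin covers
# |act| < 0.25, so get_min approximates the fraction of "dead" (near-zero)
# activations at each recorded batch.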
# %% ../nbs/10_activations.ipynb 58
class ActivationStats(HooksCallback):
    "Record activation stats with `append_stats` and plot them after training."
    def __init__(self, mod_filter=fc.noop): super().__init__(append_stats, mod_filter)
    def color_dim(self, figsize=(11,5)):
        "Show the log-histogram of activations for each hooked layer as an image."
        fig,axes = get_grid(len(self), figsize=figsize)
        for ax,h in zip(axes.flat, self):
            show_image(get_hist(h), ax, origin='lower')
    def dead_chart(self, figsize=(11,5)):
        "Plot the fraction of near-zero activations per layer over training."
        fig,axes = get_grid(len(self), figsize=figsize)
        for ax,h in zip(axes.flat, self):
            ax.plot(get_min(h))
            ax.set_ylim(0,1)
    def plot_stats(self, figsize=(10,4)):
        "Plot the per-batch means and standard deviations of each layer's activations."
        fig,axs = plt.subplots(1,2, figsize=figsize)
        for h in self:
            for i in 0,1: axs[i].plot(h.stats[i])
        axs[0].set_title('Means')
        axs[1].set_title('Stdevs')
        plt.legend(fc.L.range(self))
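# Usage sketch (hypothetical training setup; Learner arguments elided with
# `...` depend on your setup):
#   astats = ActivationStats(fc.risinstance(torch.nn.Conv2d))
#   learn = Learner(model, dls, ..., cbs=[astats])
#   learn.fit(1)
#   astats.color_dim()    # log-histograms of activations per layer
#   astats.dead_chart()   # fraction of near-zero activations per layer
#   astats.plot_stats()   # means and stdevs across training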