forked from fastai/course22p2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfid.py
95 lines (82 loc) · 3.09 KB
/
fid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/18_fid.ipynb.
# %% auto 0
# Names exported by `from <this module> import *`: only ImageEval; the
# underscore-prefixed helpers defined below are not part of the export list.
__all__ = ['ImageEval']
# %% ../nbs/18_fid.ipynb 2
import pickle,gzip,math,os,time,shutil,torch,random
import fastcore.all as fc,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager
from scipy import linalg
from fastcore.foundation import L
import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torch.nn import init
from torch.optim import lr_scheduler
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder
from .datasets import *
from .conv import *
from .learner import *
from .activations import *
from .init import *
from .sgd import *
from .resnet import *
from .augment import *
from .accel import *
# %% ../nbs/18_fid.ipynb 27
def _sqrtm_newton_schulz(mat, num_iters=100):
mat_nrm = mat.norm()
mat = mat.double()
Y = mat/mat_nrm
n = len(mat)
I = torch.eye(n, n).to(mat)
Z = torch.eye(n, n).to(mat)
for i in range(num_iters):
T = (3*I - Z@Y)/2
Y,Z = Y@T,T@Z
res = Y*mat_nrm.sqrt()
if ((mat-(res@res)).norm()/mat_nrm).abs()<=1e-6: break
return res
# %% ../nbs/18_fid.ipynb 28
def _calc_stats(feats):
feats = feats.squeeze()
return feats.mean(0),feats.T.cov()
def _calc_fid(m1,c1,m2,c2):
# csr = _sqrtm_newton_schulz(c1@c2)
csr = tensor(linalg.sqrtm(c1@c2, 256).real)
return (((m1-m2)**2).sum() + c1.trace() + c2.trace() - 2*csr.trace()).item()
# %% ../nbs/18_fid.ipynb 31
def _squared_mmd(x, y):
def k(a,b): return ([email protected](-2,-1)/a.shape[-1]+1)**3
m,n = x.shape[-2],y.shape[-2]
kxx,kyy,kxy = k(x,x), k(y,y), k(x,y)
kxx_sum = kxx.sum([-1,-2])-kxx.diagonal(0,-1,-2).sum(-1)
kyy_sum = kyy.sum([-1,-2])-kyy.diagonal(0,-1,-2).sum(-1)
kxy_sum = kxy.sum([-1,-2])
return kxx_sum/m/(m-1) + kyy_sum/n/(n-1) - kxy_sum*2/m/n
# %% ../nbs/18_fid.ipynb 32
def _calc_kid(x, y, maxs=50):
    """Kernel Inception Distance between feature sets `x` and `y`.

    Both sets are split into the same number of contiguous subsets. The
    unbiased squared MMD is computed for each subset pair, and the results
    are averaged.

    Args:
        x, y: 2-D feature tensors whose rows are samples.
        maxs: target maximum subset size. The number of subsets is
            ceil(min(len(x), len(y)) / maxs), but at least 4. It is then capped
            so that every subset keeps at least 2 samples.

    Returns:
        The mean squared MMD as a Python float.

    Raises:
        ValueError: if `x` or `y` has fewer than 2 samples (previously nan).
    """
    xs, ys = x.shape[0], y.shape[0]
    n = max(math.ceil(min(xs / maxs, ys / maxs)), 4)
    # The unbiased MMD divides by m*(m-1), so a subset with < 2 samples gives
    # nan/inf. Capping the count at size // 2 keeps the step between rounded
    # boundaries >= 2, hence every subset has >= 2 samples. This only changes
    # inputs that previously produced nan.
    n = min(n, xs // 2, ys // 2)
    if n < 1:
        raise ValueError(f"KID needs at least 2 samples per set, got {xs} and {ys}")
    mmd = 0.
    for i in range(n):
        cur_x = x[round(i * xs / n): round((i + 1) * xs / n)]
        cur_y = y[round(i * ys / n): round((i + 1) * ys / n)]
        mmd += _squared_mmd(cur_x, cur_y)
    return (mmd / n).item()
# %% ../nbs/18_fid.ipynb 35
class ImageEval:
    """Compare batches of samples to features of a dataset via FID and KID.

    On construction, `model` is run over `dls` once. The resulting features
    and their (mean, covariance) are cached. `fid` and `kid` then compare a
    batch of samples against that cache.
    NOTE(review): `model` is presumably a feature extractor -- confirm with callers.
    """

    def __init__(self, model, dls, cbs=None):
        # Learner used only to run the model: loss is fc.noop and there is no optimizer.
        self.learn = TrainLearner(model, dls, loss_func=fc.noop, cbs=cbs, opt_func=None)
        # NOTE(review): capture_preds()[0] is assumed to hold the model outputs -- confirm in learner.
        dls_preds = self.learn.capture_preds()[0]
        self.feats = dls_preds.float().cpu().squeeze()   # float32 CPU features of `dls`
        self.stats = _calc_stats(self.feats)              # (mean, covariance) of those features

    def get_feats(self, samp):
        """Run the model over `samp` as one batch and return float32 CPU features."""
        # Replace the loaders: an empty first loader and a second one holding a
        # single batch, `samp` paired with a placeholder target.
        # NOTE(review): assumes capture_preds iterates the second loader -- confirm.
        placeholder_targ = tensor([0])
        self.learn.dls = DataLoaders([], [(samp, placeholder_targ)])
        samp_preds = self.learn.capture_preds()[0]
        return samp_preds.float().cpu().squeeze()

    def fid(self, samp):
        """FID between the features of `samp` and the cached features of `dls`."""
        samp_mean, samp_cov = _calc_stats(self.get_feats(samp))
        ref_mean, ref_cov = self.stats
        return _calc_fid(ref_mean, ref_cov, samp_mean, samp_cov)

    def kid(self, samp):
        """KID between the cached features of `dls` and the features of `samp`."""
        samp_feats = self.get_feats(samp)
        return _calc_kid(self.feats, samp_feats)