FID

- skip_showdoc: true
Author

Benedict Thekkel

from fastcore.test import test_close
from torch import distributions

# Notebook-wide display and reproducibility settings.
# Print tensors with 2 decimals, wide lines and no scientific notation.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
# Seed torch's global RNG.
torch.manual_seed(1)
# Use a reversed grayscale colormap for images by default.
mpl.rcParams['image.cmap'] = 'gray_r'

import logging
# Silence log records of severity WARNING and below for the whole process.
logging.disable(logging.WARNING)

# NOTE(review): set_seed is not defined in the visible source; presumably it
# reseeds the RNGs and so supersedes the manual_seed(1) call above — confirm.
set_seed(42)
# Cap fc.defaults.cpus at 8; it is passed as num_workers to the DataLoaders below.
if fc.defaults.cpus>8: fc.defaults.cpus=8

Classifier

# Keys of the image and label fields in a dataset batch
# (xl is read by transformi; yl is not used in the visible code).
xl,yl = 'image','label'
# Dataset identifier passed to load_dataset.
name = "fashion_mnist"
# Batch size passed to DataLoaders.from_dd.
bs = 512

@inplace
def transformi(b):
    """Replace the images stored under `b[xl]` with padded, rescaled tensors.

    Every item is converted with `TF.to_tensor`, padded by 2 on both sides of
    its last two dimensions, then mapped through `*2-1`.
    Assuming `TF.to_tensor` yields values in [0, 1], the result lies in
    [-1, 1] — TODO confirm.
    """
    def _prep(img):
        # (2,2,2,2) pads the last dimension by (2,2) and the one before it by (2,2).
        padded = F.pad(TF.to_tensor(img), (2,2,2,2))
        return padded*2-1

    b[xl] = list(map(_prep, b[xl]))

# Load the dataset and attach transformi as its transform.
dsd = load_dataset(name)
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)
# Keep one training batch around; xb and yb are reused further down.
b = xb,yb = next(iter(dls.train))
cbs = [DeviceCB(), MixedPrecision()]
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
model = torch.load('models/data_aug2.pkl')
# No optimizer is supplied (opt_func=None); the learner is only run with train=False below.
learn = Learner(model, dls, F.cross_entropy, cbs=cbs, opt_func=None)
# Hook learn.model[6] with append_outp so its outputs are collected during the pass.
hcb = HooksCallback(append_outp, mods=[learn.model[6]], on_valid=True)
learn.fit(1, train=False, cbs=[hcb])
# First 64 rows of the first captured output, cast to float.
feats = hcb.hooks[0].outp[0].float()[:64]
feats.shape
torch.Size([64, 512])
# Drop the modules at indices 8 and 7. The higher index is deleted first so that
# index 7 still refers to the same module when it is removed.
# NOTE(review): presumably this leaves the output of the previously hooked
# model[6] as the model output (the printed width is 512 in both cases) — confirm.
del(learn.model[8])
del(learn.model[7])
# Run the truncated model through capture_preds and keep both returned values.
feats,y = learn.capture_preds()
feats = feats.float()
feats.shape,y
(torch.Size([10000, 512]), tensor([9, 2, 1,  ..., 8, 1, 5]))

Calc FID

# Noise schedule handed to `sample` below: n_steps values of beta spaced
# linearly from betamin to betamax.
betamin = 0.0001
betamax = 0.02
n_steps = 1000
beta = torch.linspace(start=betamin, end=betamax, steps=n_steps)
# alpha is the complement of beta; alphabar is its running product.
alpha = 1. - beta
alphabar = torch.cumprod(alpha, dim=0)
# sigma is the element-wise square root of beta.
sigma = torch.sqrt(beta)
# Build loaders from the train/test splits via dl_ddpm.
# NOTE(review): dls2 is not referenced again in the visible code — confirm it is still needed.
dls2 = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))
# NOTE(review): torch.load unpickles the file (trusted sources only); .cuda() requires a CUDA device.
smodel = torch.load('models/fashion_ddpm_mp.pkl').cuda()
# Draw samples of shape (256, 1, 32, 32) using the schedule defined above.
# `samples` is indexed by step in later cells (samples[-1], samples[i]).
samples = sample(smodel, (256, 1, 32, 32), alpha, alphabar, sigma, n_steps)
CPU times: user 1min 47s, sys: 7.21 s, total: 1min 54s
Wall time: 1min 53s
# Take the last entry of `samples` and rescale it with *2-1.
# NOTE(review): the plots further down rescale with .clamp(-0.5,0.5)*2 instead,
# which implies a different value range for `samples` — confirm which is correct.
s = samples[-1]*2-1
# Preview the first 16 generated images.
show_images(s[:16], imsize=1.5)

# Wrap the generated images as a single validation batch (the training side is empty).
# yb comes from the real batch taken earlier; the loss is fc.noop, so it presumably
# only serves as a placeholder target — confirm.
# NOTE(review): `model` is the object that had modules 8 and 7 deleted through
# learn.model above, assuming Learner keeps a reference rather than a copy.
clearn = TrainLearner(model, DataLoaders([],[(s,yb)]), loss_func=fc.noop, cbs=[DeviceCB()], opt_func=None)
feats2,y2 = clearn.capture_preds()
# Cast to float and drop size-1 dimensions.
feats2 = feats2.float().squeeze()
feats2.shape
torch.Size([256, 512])
# Mean of the real-image features over dimension 0 (one value per feature column).
means = feats.mean(0)
means.shape
torch.Size([512])
# Tensor.cov treats rows as variables and columns as observations, so the
# transpose is taken first to get a feature-by-feature covariance matrix.
covs = feats.T.cov()
covs.shape
torch.Size([512, 512])
# Statistics of the real (feats) and generated (feats2) features, fed to _calc_fid.
# _calc_stats and _calc_fid are defined outside the visible source.
s1,s2 = _calc_stats(feats),_calc_stats(feats2)
_calc_fid(*s1, *s2)
33.83489121216962
# KID between the real and generated features (_calc_kid is defined outside the visible source).
_calc_kid(feats, feats2)
0.05612194538116455

FID class

# Build an ImageEval from the truncated classifier and the learner's DataLoaders,
# then score the generated images `s`.
ie = ImageEval(model, learn.dls, cbs=[DeviceCB()])
ie.fid(s)
CPU times: user 7.38 s, sys: 234 ms, total: 7.62 s
Wall time: 263 ms
33.90600362686632
# KID of the generated images `s` from the same evaluator.
ie.kid(s)
CPU times: user 714 ms, sys: 23 ms, total: 737 ms
Wall time: 23 ms
0.0564301423728466
# Indices into `samples` to evaluate: every 50th from 0 up to 950, plus 975, 990 and 999.
xs = L.range(0,1000,50)+[975,990,999]
# Plot FID against the sample index.
# NOTE(review): samples are rescaled here with .clamp(-0.5,0.5)*2, whereas `s`
# was built with *2-1 — confirm which scaling matches the range `sample` returns.
plt.plot(xs, [ie.fid(samples[i].clamp(-0.5,0.5)*2) for i in xs]);

# Same indices as the FID plot: every 50th from 0 up to 950, plus 975, 990 and 999.
xs = L.range(0,1000,50)+[975,990,999]
# Plot KID against the sample index.
# NOTE(review): uses .clamp(-0.5,0.5)*2 rescaling, unlike the *2-1 used for `s` — confirm.
plt.plot(xs, [ie.kid(samples[i].clamp(-0.5,0.5)*2) for i in xs]);

# Baseline: FID of a batch of real training images (xb comes from dls.train).
ie.fid(xb)
6.615052956342197
# Baseline: KID of the same real training batch.
ie.kid(xb)
-0.02641688659787178

Inception

from pytorch_fid.inception import InceptionV3
# Small demo of Tensor.repeat: (3,1) tiles the 3-element vector into 3 identical rows.
a = tensor([1,2,3])
a.repeat((3,1))
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
class IncepWrap(nn.Module):
    """Adapter that feeds image batches to `InceptionV3` and returns its first output.

    The wrapped network is created with `resize_input=True`.

    NOTE(review): pytorch_fid's `InceptionV3` normalizes its input by default and
    documents an expected input range of (0, 1); the images produced by
    `transformi` are mapped through `*2-1` — confirm the ranges agree.
    """

    def __init__(self):
        super().__init__()
        self.m = InceptionV3(resize_input=True)

    def forward(self, x):
        """Tile dim 1 of `x` three times, run the network, return element 0 of its output.

        Assumes `x` is (N, 1, H, W), so that tiling yields 3 channels — TODO confirm.
        """
        tiled = x.repeat(1, 3, 1, 1)
        outputs = self.m(tiled)
        return outputs[0]
# Rebuild the transformed dataset and DataLoaders, then evaluate with the
# Inception wrapper in place of the custom classifier.
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)
ie = ImageEval(IncepWrap(), dls, cbs=[DeviceCB()])
# FID of the generated images `s` under the Inception features.
ie.fid(s)
CPU times: user 1min 11s, sys: 1.61 s, total: 1min 13s
Wall time: 2.31 s
63.81579821823857
# Baseline: Inception-based FID of the real training batch xb.
ie.fid(xb)
27.95811916882883
# Inception-based KID of the generated images `s`.
ie.kid(s)
CPU times: user 7.44 s, sys: 140 ms, total: 7.58 s
Wall time: 255 ms
0.010766863822937012
# Baseline: Inception-based KID of the real training batch xb.
ie.kid(xb)
-8.697943121660501e-05
Back to top