from fastcore.test import test_close
from torch import distributions
# Notebook-wide display defaults and seeding.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'

import logging
# Suppress log records at WARNING level and below.
logging.disable(logging.WARNING)

set_seed(42)
# Cap the worker count later passed as num_workers=fc.defaults.cpus.
if fc.defaults.cpus > 8:
    fc.defaults.cpus = 8
FID
- skip_showdoc: true
Classifier
# Keys of the image and label fields in each dataset record.
xl,yl = 'image','label'
# Dataset name passed to load_dataset.
name = "fashion_mnist"
# Batch size for DataLoaders.from_dd.
bs = 512
@inplace
def transformi(b):
    """Replace ``b[xl]`` with padded, rescaled image tensors.

    Each image is converted with ``TF.to_tensor``, zero-padded by 2 pixels on
    every side, and then mapped through ``*2-1``. Because padding happens
    before the rescale, the added border ends up at -1.
    """
    converted = []
    for img in b[xl]:
        as_tensor = TF.to_tensor(img)
        converted.append(F.pad(as_tensor, (2, 2, 2, 2)) * 2 - 1)
    b[xl] = converted
dsd = load_dataset(name)
# Attach transformi so records are converted as they are read.
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)
# Keep one training batch: b is the whole batch, xb/yb its two parts.
b = xb,yb = next(iter(dls.train))
cbs = [DeviceCB(), MixedPrecision()]
# Presumably the classifier trained earlier in the course; used here as a feature extractor.
# weights_only=False: this file pickles a whole nn.Module, and PyTorch >= 2.6
# defaults torch.load to weights_only=True, which refuses that. Only load trusted files this way.
model = torch.load('models/data_aug2.pkl', weights_only=False)
learn = Learner(model, dls, F.cross_entropy, cbs=cbs, opt_func=None)
# Record the outputs of learn.model[6] during one validation-only pass.
hcb = HooksCallback(append_outp, mods=[learn.model[6]], on_valid=True)
learn.fit(1, train=False, cbs=[hcb])
feats = hcb.hooks[0].outp[0].float()[:64]
feats.shape
torch.Size([64, 512])
# Alternative: delete the last two modules so capture_preds returns the
# 512-d activations (shape below) rather than the model's final outputs.
del(learn.model[8])
del(learn.model[7])
feats,y = learn.capture_preds()
feats = feats.float()
feats.shape,y
(torch.Size([10000, 512]), tensor([9, 2, 1, ..., 8, 1, 5]))
Calc FID
= 0.0001,0.02,1000
betamin,betamax,n_steps = torch.linspace(betamin, betamax, n_steps)
beta = 1.-beta
alpha = alpha.cumprod(dim=0)
alphabar = beta.sqrt() sigma
= DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test'])) dls2
= torch.load('models/fashion_ddpm_mp.pkl').cuda() smodel
= sample(smodel, (256, 1, 32, 32), alpha, alphabar, sigma, n_steps) samples
CPU times: user 1min 47s, sys: 7.21 s, total: 1min 54s
Wall time: 1min 53s
# Use the last entry of samples, rescaled with *2-1.
# NOTE(review): the per-step FID/KID plots rescale samples with
# .clamp(-0.5,0.5)*2 instead; the two mappings disagree, so confirm which
# value range `sample` actually returns.
s = samples[-1]*2-1
show_images(s[:16], imsize=1.5)
# Run the (layer-trimmed) classifier over the generated images. yb only fills
# the label slot of the batch; loss_func is fc.noop so it is never used.
clearn = TrainLearner(model, DataLoaders([],[(s,yb)]), loss_func=fc.noop, cbs=[DeviceCB()], opt_func=None)
feats2,y2 = clearn.capture_preds()
feats2 = feats2.float().squeeze()
feats2.shape
torch.Size([256, 512])
# Per-feature mean and covariance of the real-image features.
means = feats.mean(0)
means.shape
torch.Size([512])
# torch.cov treats rows as variables, so transpose to get a feature x feature matrix.
covs = feats.T.cov()
covs.shape
torch.Size([512, 512])
# Helper-based statistics for real (s1) and generated (s2) features;
# presumably (mean, covariance) pairs like the manual ones above.
s1,s2 = _calc_stats(feats),_calc_stats(feats2)
_calc_fid(*s1, *s2)
33.83489121216962
_calc_kid(feats, feats2)
0.05612194538116455
FID class
= ImageEval(model, learn.dls, cbs=[DeviceCB()]) ie
ie.fid(s)
CPU times: user 7.38 s, sys: 234 ms, total: 7.62 s
Wall time: 263 ms
33.90600362686632
ie.kid(s)
CPU times: user 714 ms, sys: 23 ms, total: 737 ms
Wall time: 23 ms
0.0564301423728466
= L.range(0,1000,50)+[975,990,999]
xs -0.5,0.5)*2) for i in xs]); plt.plot(xs, [ie.fid(samples[i].clamp(
= L.range(0,1000,50)+[975,990,999]
xs -0.5,0.5)*2) for i in xs]); plt.plot(xs, [ie.kid(samples[i].clamp(
ie.fid(xb)
6.615052956342197
ie.kid(xb)
-0.02641688659787178
Inception
from pytorch_fid.inception import InceptionV3
# Tensor.repeat tiles along each dim; (3,1) stacks the 1-D tensor into 3 rows.
a = tensor([1,2,3])
a.repeat((3,1))
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
class IncepWrap(nn.Module):
    """Feed single-channel image batches to pytorch_fid's ``InceptionV3``.

    ``forward`` tiles the channel axis to 3 with ``x.repeat(1,3,1,1)`` (e.g.
    a (256, 1, 32, 32) batch becomes (256, 3, 32, 32)) and returns element
    ``[0]`` of the list that ``InceptionV3`` returns.

    Args:
        normalize_input: forwarded to ``InceptionV3``. When True, InceptionV3
            assumes inputs in (0, 1) and rescales them to (-1, 1) itself.
            The images in this notebook are already mapped to [-1, 1]
            (``transformi`` applies ``*2-1``), so the default is False to
            avoid a second rescale to [-3, 1]. Pass True for [0, 1] inputs.
    """

    def __init__(self, normalize_input=False):
        super().__init__()
        self.m = InceptionV3(resize_input=True, normalize_input=normalize_input)

    def forward(self, x):
        # assumes x has a single channel at dim 1 — TODO confirm for other callers
        return self.m(x.repeat(1,3,1,1))[0]
= dsd.with_transform(transformi)
tds = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus) dls
= ImageEval(IncepWrap(), dls, cbs=[DeviceCB()]) ie
ie.fid(s)
CPU times: user 1min 11s, sys: 1.61 s, total: 1min 13s
Wall time: 2.31 s
63.81579821823857
ie.fid(xb)
27.95811916882883
ie.kid(s)
CPU times: user 7.44 s, sys: 140 ms, total: 7.58 s
Wall time: 255 ms
0.010766863822937012
ie.kid(xb)
-8.697943121660501e-05