import os
os.environ['CUDA_VISIBLE_DEVICES']='1'
Diffusion unet
import logging
from functools import partial

import torch
import fastcore.all as fc
import matplotlib as mpl
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.utils.data import default_collate
from datasets import load_dataset

# set_seed, inplace, DataLoaders, dl_ddpm, noisify, Learner, the callbacks,
# sample/ddim_step, show_images, and the UNet building blocks used below are
# assumed to come from the course's miniai package.

torch.set_printoptions(precision=4, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'
mpl.rcParams['figure.dpi'] = 70

logging.disable(logging.WARNING)

set_seed(42)
if fc.defaults.cpus>8: fc.defaults.cpus=8
xl,yl = 'image','label'
name = "fashion_mnist"
bs = 512
dsd = load_dataset(name)
@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))-0.5 for o in b[xl]]  # pad 28x28 to 32x32, center around zero
tds = dsd.with_transform(transformi)
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))

dl = dls.train
(xt,t),eps = b = next(iter(dl))
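Here dl_ddpm and noisify come from the earlier DDPM notebook: the collate function picks a random timestep per image and mixes in Gaussian noise, so each batch is ((xt, t), eps) as unpacked above. A minimal sketch of the assumed behaviour (the linear beta schedule and the helper name noisify_sketch are assumptions, not the course's exact code):

n_steps = 1000
beta = torch.linspace(0.0001, 0.02, n_steps)  # linear schedule (assumed)
alphabar = (1-beta).cumprod(dim=0)

def noisify_sketch(x0):
    # Pick a random timestep per image, then blend in Gaussian noise
    # according to the cumulative signal level alphabar[t].
    t = torch.randint(0, n_steps, (len(x0),))
    eps = torch.randn_like(x0)
    abar = alphabar[t].reshape(-1,1,1,1)
    xt = abar.sqrt()*x0 + (1-abar).sqrt()*eps
    return (xt, t), eps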
Train
Based on Diffusers
# This version is giving poor results - use the cell below instead
class SelfAttention(nn.Module):
    def __init__(self, ni, attn_chans):
        super().__init__()
        self.attn = nn.MultiheadAttention(ni, ni//attn_chans, batch_first=True)
        self.norm = nn.BatchNorm2d(ni)

    def forward(self, x):
        n,c,h,w = x.shape
        x = self.norm(x).view(n, c, -1).transpose(1, 2)
        x = self.attn(x, x, x, need_weights=False)[0]
        return x.transpose(1,2).reshape(n,c,h,w)
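As a quick sanity check, the block keeps the input shape unchanged; a small sketch with hypothetical sizes (32 channels, 8 channels per head, so 4 heads):

sa = SelfAttention(32, 8)
x = torch.randn(4, 32, 16, 16)
sa(x).shape   # torch.Size([4, 32, 16, 16])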
lr = 1e-2
epochs = 25
opt_func = partial(optim.Adam, eps=1e-5)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
model = EmbUNetModel(in_channels=1, out_channels=1, nfs=(32,64,128,256), num_layers=2)
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
loss | epoch | train |
---|---|---|
0.150 | 0 | train |
0.086 | 0 | eval |
0.069 | 1 | train |
0.171 | 1 | eval |
0.057 | 2 | train |
0.071 | 2 | eval |
0.050 | 3 | train |
0.055 | 3 | eval |
0.045 | 4 | train |
0.050 | 4 | eval |
0.043 | 5 | train |
0.073 | 5 | eval |
0.041 | 6 | train |
0.044 | 6 | eval |
0.039 | 7 | train |
0.044 | 7 | eval |
0.038 | 8 | train |
0.043 | 8 | eval |
0.038 | 9 | train |
0.058 | 9 | eval |
0.038 | 10 | train |
0.044 | 10 | eval |
0.036 | 11 | train |
0.042 | 11 | eval |
0.035 | 12 | train |
0.038 | 12 | eval |
0.035 | 13 | train |
0.039 | 13 | eval |
0.034 | 14 | train |
0.036 | 14 | eval |
0.034 | 15 | train |
0.036 | 15 | eval |
0.034 | 16 | train |
0.034 | 16 | eval |
0.034 | 17 | train |
0.035 | 17 | eval |
0.033 | 18 | train |
0.033 | 18 | eval |
0.033 | 19 | train |
0.033 | 19 | eval |
0.033 | 20 | train |
0.033 | 20 | eval |
0.033 | 21 | train |
0.032 | 21 | eval |
0.032 | 22 | train |
0.034 | 22 | eval |
0.032 | 23 | train |
0.032 | 23 | eval |
0.032 | 24 | train |
0.033 | 24 | eval |
Sampling
from miniai.fid import ImageEval

cmodel = torch.load('models/data_aug2.pkl')
del(cmodel[8])   # drop the final layers so the network outputs
del(cmodel[7])   # features for FID/KID rather than class scores

@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))*2-1 for o in b[xl]]  # scale to [-1,1] for the eval model

bs = 2048
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)

dt = dls.train
xb,yb = next(iter(dt))

ie = ImageEval(cmodel, dls, cbs=[DeviceCB()])
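The cells below use sample and ddim_step from the earlier DDIM notebook. As a reminder, a minimal sketch of the DDIM update (illustrative names and signature, not the course's exact API):

def ddim_step_sketch(x_t, eps_hat, abar_t, abar_prev, eta):
    # Estimate x0 from the current sample and the predicted noise,
    # then move to the previous, less-noisy timestep.
    sig = eta * ((1-abar_prev)/(1-abar_t)).sqrt() * (1-abar_t/abar_prev).sqrt()
    x0_hat = (x_t - (1-abar_t).sqrt()*eps_hat) / abar_t.sqrt()
    x_prev = abar_prev.sqrt()*x0_hat + (1-abar_prev-sig**2).sqrt()*eps_hat
    if eta > 0: x_prev += sig*torch.randn_like(x_prev)
    return x_prev

eta scales the fresh noise injected at each step: eta=1. behaves like DDPM, while eta=0. makes the sampler deterministic given the initial noise.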
sz = (2048,1,32,32)

# set_seed(42)
preds = sample(ddim_step, model, sz, steps=100, eta=1.)
s = (preds[-1]*2)
s.min(),s.max(),s.shape
100.00% [100/100 00:53<00:00]
(tensor(-1.0918), tensor(1.4292), torch.Size([2048, 1, 32, 32]))
show_images(s[:25].clamp(-1,1), imsize=1.5)
ie.fid(s),ie.kid(s),s.shape
(4.058064770194278, 0.010895456187427044, torch.Size([2048, 1, 32, 32]))
preds = sample(ddim_step, model, sz, steps=100, eta=1.)
ie.fid(preds[-1]*2)
100.00% [100/100 00:53<00:00]
5.320260029850715
preds = sample(ddim_step, model, sz, steps=50, eta=1.)
ie.fid(preds[-1]*2)
100.00% [50/50 00:26<00:00]
5.243807277315682
preds = sample(ddim_step, model, sz, steps=50, eta=1.)
ie.fid(preds[-1]*2)   # sampling is stochastic, so FID varies a little between runs
100.00% [50/50 00:26<00:00]
4.963977301033992
Conditional model
def collate_ddpm(b):
    b = default_collate(b)
    (xt,t),eps = noisify(b[xl])
    return (xt,t,b[yl]),eps   # pass the class label through alongside xt and t
@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))-0.5 for o in b[xl]]

tds = dsd.with_transform(transformi)
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))

dl = dls.train
(xt,t,c),eps = b = next(iter(dl))
class CondUNetModel(nn.Module):
    def __init__(self, n_classes, in_channels=3, out_channels=3, nfs=(224,448,672,896), num_layers=1):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, nfs[0], kernel_size=3, padding=1)
        self.n_temb = nf = nfs[0]
        n_emb = nf*4
        self.cond_emb = nn.Embedding(n_classes, n_emb)
        self.emb_mlp = nn.Sequential(lin(self.n_temb, n_emb, norm=nn.BatchNorm1d),
                                     lin(n_emb, n_emb))
        self.downs = nn.ModuleList()
        for i in range(len(nfs)):
            ni = nf
            nf = nfs[i]
            self.downs.append(DownBlock(n_emb, ni, nf, add_down=i!=len(nfs)-1, num_layers=num_layers))
        self.mid_block = EmbResBlock(n_emb, nfs[-1])

        rev_nfs = list(reversed(nfs))
        nf = rev_nfs[0]
        self.ups = nn.ModuleList()
        for i in range(len(nfs)):
            prev_nf = nf
            nf = rev_nfs[i]
            ni = rev_nfs[min(i+1, len(nfs)-1)]
            self.ups.append(UpBlock(n_emb, ni, prev_nf, nf, add_up=i!=len(nfs)-1, num_layers=num_layers+1))
        self.conv_out = pre_conv(nfs[0], out_channels, act=nn.SiLU, norm=nn.BatchNorm2d, bias=False)

    def forward(self, inp):
        x,t,c = inp
        temb = timestep_embedding(t, self.n_temb)
        cemb = self.cond_emb(c)
        emb = self.emb_mlp(temb) + cemb   # add the class embedding to the timestep embedding
        x = self.conv_in(x)
        saved = [x]
        for block in self.downs: x = block(x, emb)
        saved += [p for o in self.downs for p in o.saved]
        x = self.mid_block(x, emb)
        for block in self.ups: x = block(x, emb, saved)
        return self.conv_out(x)
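timestep_embedding was defined earlier in the course; a minimal sketch of the usual sinusoidal version (details such as the frequency spacing may differ from the course implementation):

import math

def timestep_embedding(tsteps, emb_dim, max_period=10000):
    # Sinusoidal features at log-spaced frequencies, as in the transformer
    # positional encoding; returns a (batch, emb_dim) tensor.
    exponent = -math.log(max_period) * torch.linspace(0, 1, emb_dim//2, device=tsteps.device)
    emb = tsteps[:,None].float() * exponent.exp()[None,:]
    return torch.cat([emb.sin(), emb.cos()], dim=-1)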
lr = 1e-2
epochs = 25
opt_func = partial(optim.Adam, eps=1e-5)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
model = CondUNetModel(10, in_channels=1, out_channels=1, nfs=(32,64,128,256), num_layers=2)
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
loss | epoch | train |
---|---|---|
0.178 | 0 | train |
0.099 | 0 | eval |
0.072 | 1 | train |
0.066 | 1 | eval |
0.053 | 2 | train |
0.053 | 2 | eval |
0.047 | 3 | train |
0.050 | 3 | eval |
0.045 | 4 | train |
0.045 | 4 | eval |
0.042 | 5 | train |
0.048 | 5 | eval |
0.041 | 6 | train |
0.060 | 6 | eval |
0.039 | 7 | train |
0.042 | 7 | eval |
0.037 | 8 | train |
0.039 | 8 | eval |
0.037 | 9 | train |
0.051 | 9 | eval |
0.036 | 10 | train |
0.039 | 10 | eval |
0.035 | 11 | train |
0.041 | 11 | eval |
0.035 | 12 | train |
0.041 | 12 | eval |
0.034 | 13 | train |
0.035 | 13 | eval |
0.034 | 14 | train |
0.035 | 14 | eval |
0.034 | 15 | train |
0.036 | 15 | eval |
0.033 | 16 | train |
0.037 | 16 | eval |
0.033 | 17 | train |
0.032 | 17 | eval |
0.032 | 18 | train |
0.036 | 18 | eval |
0.032 | 19 | train |
0.033 | 19 | eval |
0.032 | 20 | train |
0.032 | 20 | eval |
0.032 | 21 | train |
0.033 | 21 | eval |
0.032 | 22 | train |
0.033 | 22 | eval |
0.031 | 23 | train |
0.032 | 23 | eval |
0.031 | 24 | train |
0.033 | 24 | eval |
sz = (256,1,32,32)

lbls = dsd['train'].features[yl].names
lbls
['T - shirt / top',
'Trouser',
'Pullover',
'Dress',
'Coat',
'Sandal',
'Shirt',
'Sneaker',
'Bag',
'Ankle boot']
set_seed(42)
cid = 0
preds = sample(cid, ddim_step, model, sz, steps=100, eta=1.)
s = (preds[-1]*2)
show_images(s[:25].clamp(-1,1), imsize=1.5, suptitle=lbls[cid])
100.00% [100/100 00:02<00:00]
set_seed(42)
cid = 0
preds = sample(cid, ddim_step, model, sz, steps=100, eta=0.)
s = (preds[-1]*2)
show_images(s[:25].clamp(-1,1), imsize=1.5, suptitle=lbls[cid])
100.00% [100/100 00:02<00:00]
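The same call works for any class id; a small usage sketch reusing the pieces above to render one grid per category:

for cid in range(len(lbls)):
    preds = sample(cid, ddim_step, model, sz, steps=100, eta=0.)
    s = (preds[-1]*2)
    show_images(s[:25].clamp(-1,1), imsize=1.5, suptitle=lbls[cid])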