Diffusion unet

Author

Benedict Thekkel

import os
os.environ['CUDA_VISIBLE_DEVICES']='1'
torch.set_printoptions(precision=4, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'
mpl.rcParams['figure.dpi'] = 70

import logging
logging.disable(logging.WARNING)

set_seed(42)
if fc.defaults.cpus>8: fc.defaults.cpus=8
xl,yl = 'image','label'
name = "fashion_mnist"
bs = 512
dsd = load_dataset(name)
@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))-0.5 for o in b[xl]]

tds = dsd.with_transform(transformi)
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))

dl = dls.train
(xt,t),eps = b = next(iter(dl))

Train

Based on Diffusers

# This version is giving poor results - use the cell below instead
class SelfAttention(nn.Module):
    def __init__(self, ni, attn_chans):
        super().__init__()
        self.attn = nn.MultiheadAttention(ni, ni//attn_chans, batch_first=True)
        self.norm = nn.BatchNorm2d(ni)

    def forward(self, x):
        n,c,h,w = x.shape
        x = self.norm(x).view(n, c, -1).transpose(1, 2)
        x = self.attn(x, x, x, need_weights=False)[0]
        return x.transpose(1,2).reshape(n,c,h,w)
lr = 1e-2
epochs = 25
opt_func = partial(optim.Adam, eps=1e-5)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
model = EmbUNetModel(in_channels=1, out_channels=1, nfs=(32,64,128,256), num_layers=2)
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
loss epoch train
0.150 0 train
0.086 0 eval
0.069 1 train
0.171 1 eval
0.057 2 train
0.071 2 eval
0.050 3 train
0.055 3 eval
0.045 4 train
0.050 4 eval
0.043 5 train
0.073 5 eval
0.041 6 train
0.044 6 eval
0.039 7 train
0.044 7 eval
0.038 8 train
0.043 8 eval
0.038 9 train
0.058 9 eval
0.038 10 train
0.044 10 eval
0.036 11 train
0.042 11 eval
0.035 12 train
0.038 12 eval
0.035 13 train
0.039 13 eval
0.034 14 train
0.036 14 eval
0.034 15 train
0.036 15 eval
0.034 16 train
0.034 16 eval
0.034 17 train
0.035 17 eval
0.033 18 train
0.033 18 eval
0.033 19 train
0.033 19 eval
0.033 20 train
0.033 20 eval
0.033 21 train
0.032 21 eval
0.032 22 train
0.034 22 eval
0.032 23 train
0.032 23 eval
0.032 24 train
0.033 24 eval

Sampling

from miniai.fid import ImageEval
cmodel = torch.load('models/data_aug2.pkl')
del(cmodel[8])
del(cmodel[7])

@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))*2-1 for o in b[xl]]

bs = 2048
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)

dt = dls.train
xb,yb = next(iter(dt))

ie = ImageEval(cmodel, dls, cbs=[DeviceCB()])
sz = (2048,1,32,32)
# set_seed(42)
preds = sample(ddim_step, model, sz, steps=100, eta=1.)
s = (preds[-1]*2)
s.min(),s.max(),s.shape
100.00% [100/100 00:53<00:00]
(tensor(-1.0918), tensor(1.4292), torch.Size([2048, 1, 32, 32]))
show_images(s[:25].clamp(-1,1), imsize=1.5)

ie.fid(s),ie.kid(s),s.shape
(4.058064770194278, 0.010895456187427044, torch.Size([2048, 1, 32, 32]))
preds = sample(ddim_step, model, sz, steps=100, eta=1.)
ie.fid(preds[-1]*2)
100.00% [100/100 00:53<00:00]
5.320260029850715
preds = sample(ddim_step, model, sz, steps=50, eta=1.)
ie.fid(preds[-1]*2)
100.00% [50/50 00:26<00:00]
5.243807277315682
preds = sample(ddim_step, model, sz, steps=50, eta=1.)
ie.fid(preds[-1]*2)
100.00% [50/50 00:26<00:00]
4.963977301033992

Conditional model

def collate_ddpm(b):
    b = default_collate(b)
    (xt,t),eps = noisify(b[xl])
    return (xt,t,b[yl]),eps
@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))-0.5 for o in b[xl]]

tds = dsd.with_transform(transformi)
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))

dl = dls.train
(xt,t,c),eps = b = next(iter(dl))
class CondUNetModel(nn.Module):
    def __init__( self, n_classes, in_channels=3, out_channels=3, nfs=(224,448,672,896), num_layers=1):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, nfs[0], kernel_size=3, padding=1)
        self.n_temb = nf = nfs[0]
        n_emb = nf*4
        self.cond_emb = nn.Embedding(n_classes, n_emb)
        self.emb_mlp = nn.Sequential(lin(self.n_temb, n_emb, norm=nn.BatchNorm1d),
                                     lin(n_emb, n_emb))
        self.downs = nn.ModuleList()
        for i in range(len(nfs)):
            ni = nf
            nf = nfs[i]
            self.downs.append(DownBlock(n_emb, ni, nf, add_down=i!=len(nfs)-1, num_layers=num_layers))
        self.mid_block = EmbResBlock(n_emb, nfs[-1])

        rev_nfs = list(reversed(nfs))
        nf = rev_nfs[0]
        self.ups = nn.ModuleList()
        for i in range(len(nfs)):
            prev_nf = nf
            nf = rev_nfs[i]
            ni = rev_nfs[min(i+1, len(nfs)-1)]
            self.ups.append(UpBlock(n_emb, ni, prev_nf, nf, add_up=i!=len(nfs)-1, num_layers=num_layers+1))
        self.conv_out = pre_conv(nfs[0], out_channels, act=nn.SiLU, norm=nn.BatchNorm2d, bias=False)

    def forward(self, inp):
        x,t,c = inp
        temb = timestep_embedding(t, self.n_temb)
        cemb = self.cond_emb(c)
        emb = self.emb_mlp(temb) + cemb
        x = self.conv_in(x)
        saved = [x]
        for block in self.downs: x = block(x, emb)
        saved += [p for o in self.downs for p in o.saved]
        x = self.mid_block(x, emb)
        for block in self.ups: x = block(x, emb, saved)
        return self.conv_out(x)
lr = 1e-2
epochs = 25
opt_func = partial(optim.Adam, eps=1e-5)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
model = CondUNetModel(10, in_channels=1, out_channels=1, nfs=(32,64,128,256), num_layers=2)
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
loss epoch train
0.178 0 train
0.099 0 eval
0.072 1 train
0.066 1 eval
0.053 2 train
0.053 2 eval
0.047 3 train
0.050 3 eval
0.045 4 train
0.045 4 eval
0.042 5 train
0.048 5 eval
0.041 6 train
0.060 6 eval
0.039 7 train
0.042 7 eval
0.037 8 train
0.039 8 eval
0.037 9 train
0.051 9 eval
0.036 10 train
0.039 10 eval
0.035 11 train
0.041 11 eval
0.035 12 train
0.041 12 eval
0.034 13 train
0.035 13 eval
0.034 14 train
0.035 14 eval
0.034 15 train
0.036 15 eval
0.033 16 train
0.037 16 eval
0.033 17 train
0.032 17 eval
0.032 18 train
0.036 18 eval
0.032 19 train
0.033 19 eval
0.032 20 train
0.032 20 eval
0.032 21 train
0.033 21 eval
0.032 22 train
0.033 22 eval
0.031 23 train
0.032 23 eval
0.031 24 train
0.033 24 eval

sz = (256,1,32,32)
lbls = dsd['train'].features[yl].names
lbls
['T - shirt / top',
 'Trouser',
 'Pullover',
 'Dress',
 'Coat',
 'Sandal',
 'Shirt',
 'Sneaker',
 'Bag',
 'Ankle boot']
set_seed(42)
cid = 0
preds = sample(cid, ddim_step, model, sz, steps=100, eta=1.)
s = (preds[-1]*2)
show_images(s[:25].clamp(-1,1), imsize=1.5, suptitle=lbls[cid])
100.00% [100/100 00:02<00:00]

set_seed(42)
cid = 0
preds = sample(cid, ddim_step, model, sz, steps=100, eta=0.)
s = (preds[-1]*2)
show_images(s[:25].clamp(-1,1), imsize=1.5, suptitle=lbls[cid])
100.00% [100/100 00:02<00:00]

Back to top