import pickle,gzip,math,os,time,shutil,torch,random,logging
import fastcore.all as fc,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from collections.abc import Mapping
from pathlib import Path
from functools import partial
from fastcore.foundation import L
import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torch.nn import init
from torch.optim import lr_scheduler
from fastAIcourse.datasets import *
from fastAIcourse.conv import *
from fastAIcourse.learner import *
from fastAIcourse.activations import *
from fastAIcourse.init import *
from fastAIcourse.sgd import *
from fastAIcourse.resnet import *
from fastAIcourse.augment import *
from fastAIcourse.accel import *
from fastAIcourse.fid import *
Denoising Diffusion Implicit Models - DDIM
from fastprogress.fastprogress import progress_bar
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder
mpl.rcParams['image.cmap'] = 'gray_r'
logging.disable(logging.WARNING)

xl,yl = 'image','label'
name = "fashion_mnist"
dsd = load_dataset(name)
from diffusers import UNet2DModel, DDIMPipeline, DDPMPipeline, DDIMScheduler, DDPMScheduler
Diffusers DDPM Scheduler
class UNet(UNet2DModel): pass  # re-declared here so torch.load below can resolve the pickled class
model = torch.load('models/fashion_ddpm3_25.pkl').cuda()
# model = torch.load('models/fashion_no-t.pkl').cuda()
sched = DDPMScheduler(beta_end=0.01)

x_t = torch.randn((4,1,32,32)).cuda()
t = 99
t_batch = torch.full((len(x_t),), t, device=x_t.device, dtype=torch.long)
with torch.no_grad(): noise = model(x_t, t_batch).sample
res = sched.step(noise, t, x_t)
res.prev_sample.shape
torch.Size([4, 1, 32, 32])
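Under the hood, sched.step applies the standard DDPM reverse-process update for an ε-prediction model (a sketch of the usual formula, not diffusers' exact internals):

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\varepsilon_\theta(x_t,t)\Big) + \sigma_t z, \qquad z\sim\mathcal{N}(0,I)$$

where $\alpha_t = 1-\beta_t$ and $\bar\alpha_t = \prod_{s\le t}\alpha_s$; res.prev_sample holds the resulting $x_{t-1}$.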
sz = (512,1,32,32)

x_t = torch.randn(sz).cuda()
preds = []
for t in progress_bar(sched.timesteps):
    with torch.no_grad(): noise = model(x_t, t).sample
    x_t = sched.step(noise, t, x_t).prev_sample
    preds.append(x_t.float().cpu())
100.00% [1000/1000 03:39<00:00]
CPU times: user 3min 34s, sys: 5.14 s, total: 3min 40s
Wall time: 3min 39s
s = preds[-1].clamp(-0.5,0.5)*2
show_images(s[:25], imsize=1.5)
cmodel = torch.load('models/data_aug2.pkl')
del(cmodel[8])  # drop the classifier head so cmodel outputs features for FID/KID
del(cmodel[7])
@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))*2-1 for o in b[xl]]  # pad 28x28 -> 32x32, scale to [-1,1]
bs = 2048
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)

dt = dls.train
xb,yb = next(iter(dt))
ie = ImageEval(cmodel, dls, cbs=[DeviceCB()])
ie.fid(s),ie.kid(s)
(30.25244140625, 0.07350349426269531)
ie.fid(xb),ie.kid(xb)
(1.296875, -0.00011661575990729034)
Diffusers DDIM Scheduler
sched = DDIMScheduler(beta_end=0.01)
sched.set_timesteps(333)
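set_timesteps(333) makes the scheduler run over an evenly spaced subset of the 1000 training timesteps, and the eta keyword we forward to sched.step scales the stochastic noise term: eta=0. is the fully deterministic DDIM sampler, eta=1. injects DDPM-level noise at each step. A quick check of what the scheduler will iterate over:

len(sched.timesteps), sched.timesteps[:3]  # 333 steps, counting down from the top of the schedule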
def diff_sample(model, sz, sched, **kwargs):
    x_t = torch.randn(sz).cuda()
    preds = []
    for t in progress_bar(sched.timesteps):
        with torch.no_grad(): noise = model(x_t, t).sample
        x_t = sched.step(noise, t, x_t, **kwargs).prev_sample
        preds.append(x_t.float().cpu())
    return preds

preds = diff_sample(model, sz, sched, eta=1.)
s = (preds[-1]*2).clamp(-1,1)
100.00% [333/333 01:11<00:00]
show_images(s[:25], imsize=1.5)
ie.fid(s),ie.kid(s)
(36.6961669921875, 0.1079036071896553)
sched.set_timesteps(200)
preds = diff_sample(model, sz, sched, eta=1.)
s = (preds[-1]*2).clamp(-1,1)
ie.fid(s),ie.kid(s)
100.00% [200/200 00:42<00:00]
(33.8856201171875, 0.16268639266490936)
show_images(s[:25], imsize=1.5)
sched.set_timesteps(100)
preds = diff_sample(model, sz, sched, eta=1.)
100.00% [100/100 00:21<00:00]
s = (preds[-1]*2).clamp(-1,1)
ie.fid(s),ie.kid(s)
(35.12646484375, 0.10706407576799393)
show_images(s[:25], imsize=1.5)
sched.set_timesteps(50)
preds = diff_sample(model, sz, sched, eta=1.)
s = (preds[-1]*2).clamp(-1,1)
ie.fid(s),ie.kid(s)
100.00% [50/50 00:11<00:00]
(43.124267578125, 0.21125255525112152)
show_images(s[:25], imsize=1.5)
sched.set_timesteps(25)
preds = diff_sample(model, sz, sched, eta=1.)
s = (preds[-1]*2).clamp(-1,1)
ie.fid(s),ie.kid(s)
100.00% [25/25 00:05<00:00]
(51.3458251953125, 0.21316410601139069)
show_images(s[:25], imsize=1.5)
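Collecting the FID/KID numbers: the full 1000-step DDPM run gave 30.3/0.074, while DDIM with eta=1. gives 36.7/0.108 at 333 steps, 33.9/0.163 at 200, 35.1/0.107 at 100, 43.1/0.211 at 50, and 51.3/0.213 at 25. Sample quality holds up well down to roughly 100 steps (a 10x speedup over DDPM) and then degrades noticeably.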
Implementing DDIM
from types import SimpleNamespace
n_steps = 1000
def linear_sched(betamin=0.0001, betamax=0.02, n_steps=1000):
    beta = torch.linspace(betamin, betamax, n_steps)
    return SimpleNamespace(a=1.-beta, abar=(1.-beta).cumprod(dim=0), sig=beta.sqrt())

sc = linear_sched(betamax=0.01, n_steps=n_steps)
abar = sc.abar
def ddim_step(x_t, t, noise, abar_t, abar_t1, bbar_t, bbar_t1, eta):
    vari = (bbar_t1/bbar_t) * (1-abar_t/abar_t1)
    sig = vari.sqrt()*eta                                  # σ_t: eta scales the stochasticity
    x_0_hat = (x_t - bbar_t.sqrt()*noise) / abar_t.sqrt()  # predicted clean image
    x_t = abar_t1.sqrt()*x_0_hat + (bbar_t1-sig**2).sqrt()*noise
    if t>0: x_t += sig * torch.randn(x_t.shape).to(x_t)    # no extra noise on the final step
    return x_t
@torch.no_grad()
def sample(f, model, sz, n_steps, skips=1, eta=1.):
    tsteps = list(reversed(range(0, n_steps, skips)))
    x_t = torch.randn(sz).to(model.device)
    preds = []
    for i,t in enumerate(progress_bar(tsteps)):
        abar_t1 = abar[tsteps[i+1]] if t > 0 else torch.tensor(1)
        noise = model(x_t, t).sample
        x_t = f(x_t, t, noise, abar[t], abar_t1, 1-abar[t], 1-abar_t1, eta)
        preds.append(x_t.float().cpu())
    return preds

samples = sample(ddim_step, model, sz, 1000, 10)
100.00% [100/100 00:22<00:00]
CPU times: user 21.7 s, sys: 450 ms, total: 22.2 s
Wall time: 22.2 s
s = (samples[-1]*2)#.clamp(-1,1)
show_images(s[:25], imsize=1.5)
ie.fid(s),ie.kid(s)
(34.0198974609375, 0.11981192231178284)
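Our from-scratch DDIM at 100 effective steps (FID 34.0) lands right alongside the diffusers DDIMScheduler at the same step count (35.1), a good sign the implementation matches.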
Triangular noise
def noisify(x0, ᾱ):
    device = x0.device
    n = len(x0)
    # t = torch.randint(0, n_steps, (n,), dtype=torch.long)  # uniform t, replaced by triangular below
    t = np.random.triangular(0, 0.5, 1, (n,))*n_steps
    t = tensor(t, dtype=torch.long)
    ε = torch.randn(x0.shape, device=device)
    ᾱ_t = ᾱ[t].reshape(-1, 1, 1, 1).to(device)
    xt = ᾱ_t.sqrt()*x0 + (1-ᾱ_t).sqrt()*ε
    return (xt, t.to(device)), ε
(xt,t),ε = noisify(xb, abar)
plt.hist(t);
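np.random.triangular(0, 0.5, 1, size) samples from a triangular density that peaks at 0.5, so after scaling by n_steps the training timesteps concentrate around the middle of the schedule rather than being uniform. A minimal standalone sketch of that behaviour (exact numbers vary by seed):

import numpy as np

# Triangular(left=0, mode=0.5, right=1), scaled to [0, 1000)
t = np.random.triangular(0, 0.5, 1, 100_000) * 1000
print(t.mean())          # ≈ 500: mass centres on mid-schedule timesteps
print((t < 100).mean())  # ≈ 0.02, since CDF(0.1) = 2·0.1²: few near-clean timesteps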