import os
# os.environ['CUDA_VISIBLE_DEVICES']='1'
# Diffusion unet
import timm, torch, random, datasets, math, fastcore.all as fc, numpy as np, matplotlib as mpl, matplotlib.pyplot as plt
import k_diffusion as K, torchvision.transforms as T
import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch.utils.data import DataLoader,default_collate
from pathlib import Path
from torch.nn import init
from fastcore.foundation import L
from torch import nn,tensor
from datasets import load_dataset
from operator import itemgetter
from torcheval.metrics import MulticlassAccuracy
from functools import partial
from torch.optim import lr_scheduler
from torch import optim
from fastAIcourse.datasets import *
from fastAIcourse.conv import *
from fastAIcourse.learner import *
from fastAIcourse.activations import *
from fastAIcourse.init import *
from fastAIcourse.sgd import *
from fastAIcourse.resnet import *
from fastAIcourse.augment import *
from fastAIcourse.accel import *
from fastAIcourse.training import *
from fastprogress import progress_bar
from diffusers import UNet2DModel, DDIMPipeline, DDPMPipeline, DDIMScheduler, DDPMScheduler
# Display and reproducibility settings used by the rest of the file.
torch.set_printoptions(precision=5, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'
mpl.rcParams['figure.dpi'] = 70

import logging
logging.disable(logging.WARNING)

set_seed(42)
# Cap the worker count fastcore reports at 8.
if fc.defaults.cpus>8: fc.defaults.cpus=8

# Dataset column names, dataset identifier and loader settings.
xl,yl = 'image','label'
name = "fashion_mnist"
n_steps = 1000
bs = 512
dsd = load_dataset(name)

# Constant read by `scalings` as the data-side sigma.
sig_data = 0.66
@inplace
def transformi(b):
    # Convert each image to a tensor, pad 2 pixels on every side, then map
    # values from [0, 1] to [-1, 1]; the result overwrites the `xl` column.
    # NOTE(review): `inplace` comes from the project imports; presumably it
    # makes the wrapped function return `b` - confirm.
    b[xl] = [F.pad(TF.to_tensor(img), (2,2,2,2))*2-1 for img in b[xl]]
def scalings(sig):
    """Return the `(c_skip, c_out, c_in)` scaling factors for noise level `sig`.

    Uses the module-level `sig_data`. `sig` must support `.sqrt()` on
    `sig**2 + sig_data**2` (e.g. a torch tensor).
    """
    totvar = sig**2+sig_data**2
    # c_skip,c_out,c_in
    return sig_data**2/totvar,sig*sig_data/totvar.sqrt(),1/totvar.sqrt()
def noisify(x0):
    """Add random-strength Gaussian noise to a batch and build the training pair.

    Returns `((noised_input*c_in, sig), target)` where `sig` holds one noise
    level per batch item and `target = (x0 - c_skip*noised_input)/c_out`.
    """
    device = x0.device
    # One sigma per item: exp of a normal sample scaled by 1.2 and shifted by -1.2,
    # reshaped so it broadcasts over the remaining three dims of `x0`.
    sig = (torch.randn([len(x0)])*1.2-1.2).exp().to(x0).reshape(-1,1,1,1)
    noise = torch.randn_like(x0, device=device)
    c_skip,c_out,c_in = scalings(sig)
    noised_input = x0 + noise*sig
    target = (x0-c_skip*noised_input)/c_out
    return (noised_input*c_in,sig.squeeze()),target
def collate_ddpm(b):
    """Collate a list of samples, then noisify the `xl` column of the batch."""
    batch = default_collate(b)
    return noisify(batch[xl])
def dl_ddpm(ds):
    """Wrap `ds` in a DataLoader that uses the global `bs` and `collate_ddpm`."""
    loader_kwargs = dict(batch_size=bs, collate_fn=collate_ddpm, num_workers=0)
    return DataLoader(ds, **loader_kwargs)
# Apply the image transform lazily and build train/test loaders.
tds = dsd.with_transform(transformi)
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))
# Train
# Based on Diffusers
def unet_conv(ni, nf, ks=3, stride=1, act=nn.SiLU, norm=None, bias=True):
    """Build a `[norm] -> [act] -> Conv2d` stack as an `nn.Sequential`.

    `norm` (if given) is called with `ni`, i.e. it is applied to the conv
    input; `act` (if given) is instantiated with no arguments. The conv uses
    `padding=ks//2`.
    """
    layers = nn.Sequential()
    if norm: layers.append(norm(ni))
    if act : layers.append(act())
    layers.append(nn.Conv2d(ni, nf, stride=stride, kernel_size=ks, padding=ks//2, bias=bias))
    return layers
class UnetResBlock(nn.Module):
    """Residual block: two `unet_conv` stacks plus a skip connection.

    When `ni != nf` the skip path goes through a 1x1 conv so the channel
    counts match; otherwise the input is passed through unchanged.
    """
    def __init__(self, ni, nf=None, ks=3, act=nn.SiLU, norm=nn.BatchNorm2d):
        super().__init__()
        if nf is None: nf = ni
        self.convs = nn.Sequential(unet_conv(ni, nf, ks, act=act, norm=norm),
                                   unet_conv(nf, nf, ks, act=act, norm=norm))
        self.idconv = fc.noop if ni==nf else nn.Conv2d(ni, nf, 1)

    def forward(self, x): return self.convs(x) + self.idconv(x)
# Demonstration of cooperative `super()` in a mixin: for `C(A, B)`, the
# `super().__call__()` inside `A` resolves to `B.__call__`.
class A:
    def __call__(self):
        super().__call__()
        print('a')

class B:
    def __call__(self): print('b')

class C(A,B): pass

C()()
# Output:
# b
# a
class SaveModule:
    """Mixin that stores the result of the next class's `forward` in `self.saved`.

    Must be listed before the class providing `forward` so `super().forward`
    resolves to it.
    """
    def forward(self, x, *args, **kwargs):
        self.saved = super().forward(x, *args, **kwargs)
        return self.saved
class SavedResBlock(SaveModule, UnetResBlock):
    """`UnetResBlock` whose forward output is kept in `self.saved`."""
class SavedConv(SaveModule, nn.Conv2d):
    """`nn.Conv2d` whose forward output is kept in `self.saved`."""
def down_block(ni, nf, add_down=True, num_layers=1):
    """Build a down-sampling stage as an `nn.Sequential`.

    Contains `num_layers` `SavedResBlock`s (the first maps `ni -> nf`, the
    rest `nf -> nf`) and, when `add_down` is true, a stride-2 `SavedConv`.
    """
    res = nn.Sequential(*[SavedResBlock(ni=ni if i==0 else nf, nf=nf)
                         for i in range(num_layers)])
    if add_down: res.append(SavedConv(nf, nf, 3, stride=2, padding=1))
    return res
def upsample(nf):
    """Return a 2x `nn.Upsample` followed by a 3x3 `nf -> nf` conv."""
    resize = nn.Upsample(scale_factor=2.)
    conv = nn.Conv2d(nf, nf, 3, padding=1)
    return nn.Sequential(resize, conv)
class UpBlock(nn.Module):
    """Up-sampling stage: res blocks fed with skip activations, then optional upsample.

    Each res block receives the current activation concatenated on dim 1 with
    one tensor popped from `ups`. The first block's incoming activation has
    `prev_nf` channels (later ones `nf`); the popped skip tensor is assumed to
    have `nf` channels except for the last block, where it is `ni`.
    """
    def __init__(self, ni, prev_nf, nf, add_up=True, num_layers=2):
        super().__init__()
        self.resnets = nn.ModuleList(
            [UnetResBlock((prev_nf if i==0 else nf)+(ni if (i==num_layers-1) else nf), nf)
            for i in range(num_layers)])
        self.up = upsample(nf) if add_up else nn.Identity()

    def forward(self, x, ups):
        # `ups` is consumed from the end and mutated in place.
        for resnet in self.resnets: x = resnet(torch.cat([x, ups.pop()], dim=1))
        return self.up(x)
class UNet2DModel(nn.Module):
    """U-Net built from `down_block` / `UpBlock` stages with skip connections.

    NOTE: this definition rebinds the `UNet2DModel` name imported from
    `diffusers` at the top of the file.

    `nfs` gives the channel count of each resolution level; `num_layers` is
    the number of res blocks per down stage (up stages use `num_layers+1`).
    """
    def __init__( self, in_channels=3, out_channels=3, nfs=(224,448,672,896), num_layers=1):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, nfs[0], kernel_size=3, padding=1)
        nf = nfs[0]
        self.downs = nn.Sequential()
        for i in range(len(nfs)):
            ni = nf
            nf = nfs[i]
            # The last level has no down-sampling conv.
            self.downs.append(down_block(ni, nf, add_down=i!=len(nfs)-1, num_layers=num_layers))
        self.mid_block = UnetResBlock(nfs[-1])

        rev_nfs = list(reversed(nfs))
        nf = rev_nfs[0]
        self.ups = nn.ModuleList()
        for i in range(len(nfs)):
            prev_nf = nf
            nf = rev_nfs[i]
            ni = rev_nfs[min(i+1, len(nfs)-1)]
            self.ups.append(UpBlock(ni, prev_nf, nf, add_up=i!=len(nfs)-1, num_layers=num_layers+1))
        self.conv_out = unet_conv(nfs[0], out_channels, act=nn.SiLU, norm=nn.BatchNorm2d)

    def forward(self, inp):
        # Only the first element of `inp` is used by this model.
        x = self.conv_in(inp[0])
        saved = [x]
        x = self.downs(x)
        # Collect the activation stored by every layer of every down stage.
        saved += [p.saved for o in self.downs for p in o]
        x = self.mid_block(x)
        for block in self.ups: x = block(x, saved)
        return self.conv_out(x)
# Train the plain U-Net with a one-cycle schedule stepped per batch.
model = UNet2DModel(in_channels=1, out_channels=1, nfs=(32,64,128,256), num_layers=2)
lr = 3e-3
epochs = 25
opt_func = partial(optim.Adam, eps=1e-5)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), MixedPrecision(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched)]
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
# Timesteps
# Step-by-step exploration of the sinusoidal embedding that
# `timestep_embedding` packages up. Cell outputs are kept as comments.
emb_dim = 16
tsteps = torch.linspace(-10,10,100)
max_period = 10000

math.log(10000)
# 9.210340371976184

exponent = -math.log(max_period) * torch.linspace(0, 1, emb_dim//2, device=tsteps.device)
plt.plot(exponent);

emb = tsteps[:,None].float() * exponent.exp()[None,:]
emb.shape
# torch.Size([100, 8])

plt.plot(emb[0])
plt.plot(emb[10])
plt.plot(emb[20])
plt.plot(emb[50])
plt.plot(emb[-1]);

emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
emb.shape
# torch.Size([100, 16])

plt.plot(emb[:,0])
plt.plot(emb[:,1])
plt.plot(emb[:,2])
plt.plot(emb[:,3])
plt.plot(emb[:,4]);

plt.plot(emb[:,8])
plt.plot(emb[:,9])
plt.plot(emb[:,10]);

show_image(emb.T, figsize=(7,7));
def timestep_embedding(tsteps, emb_dim, max_period= 10000):
    """Return sinusoidal embeddings of shape `(len(tsteps), emb_dim)`.

    The first `emb_dim//2` columns are sines and the next `emb_dim//2` are
    cosines of `tsteps * f`, where the frequencies `f` run geometrically from
    1 down to `1/max_period`. For odd `emb_dim` one zero column is appended.
    """
    exponent = -math.log(max_period) * torch.linspace(0, 1, emb_dim//2, device=tsteps.device)
    emb = tsteps[:,None].float() * exponent.exp()[None,:]
    emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
    # Pad the last dim on the right by one when emb_dim is odd.
    return F.pad(emb, (0,1,0,0)) if emb_dim%2==1 else emb
# Visualise how `max_period` changes the embedding pattern.
show_image(timestep_embedding(tsteps, 32, max_period=1000).T, figsize=(7,7));
show_image(timestep_embedding(tsteps, 32, max_period=10).T, figsize=(7,7));
# Timestep model
from functools import wraps
def lin(ni, nf, act=nn.SiLU, norm=None, bias=True):
    """Build a `[norm] -> [act] -> Linear` stack as an `nn.Sequential`.

    `norm` (if given) is called with `ni`; `act` (if given) is instantiated
    with no arguments.
    """
    layers = nn.Sequential()
    if norm: layers.append(norm(ni))
    if act : layers.append(act())
    layers.append(nn.Linear(ni, nf, bias=bias))
    return layers
class EmbResBlock(nn.Module):
    """Residual block whose intermediate activation is modulated by an embedding.

    `emb_proj` maps the embedding `t` to `2*nf` values that are split into a
    per-channel `scale` and `shift`, applied between the two conv stacks as
    `x*(1+scale) + shift`.
    """
    def __init__(self, n_emb, ni, nf=None, ks=3, act=nn.SiLU, norm=nn.BatchNorm2d):
        super().__init__()
        if nf is None: nf = ni
        self.emb_proj = nn.Linear(n_emb, nf*2)
        self.conv1 = unet_conv(ni, nf, ks, act=act, norm=norm) #, bias=not norm)
        self.conv2 = unet_conv(nf, nf, ks, act=act, norm=norm)
        self.idconv = fc.noop if ni==nf else nn.Conv2d(ni, nf, 1)

    def forward(self, x, t):
        inp = x
        x = self.conv1(x)
        # Add two trailing unit dims so the projection broadcasts over H and W.
        emb = self.emb_proj(F.silu(t))[:, :, None, None]
        scale,shift = torch.chunk(emb, 2, dim=1)
        x = x*(1+scale) + shift
        x = self.conv2(x)
        return x + self.idconv(inp)
def saved(m, blk):
    """Patch `m.forward` so every result is also appended to `blk.saved`.

    `blk.saved` must be a list at call time. Returns `m` itself, with its
    `forward` attribute replaced by the wrapping function.
    """
    m_ = m.forward

    @wraps(m.forward)
    def _f(*args, **kwargs):
        res = m_(*args, **kwargs)
        blk.saved.append(res)
        return res

    m.forward = _f
    return m
class DownBlock(nn.Module):
    """Down-sampling stage of `EmbResBlock`s that records every layer output.

    Each res block (and the stride-2 conv, when `add_down` is true) is wrapped
    with `saved(..., self)`, so after `forward` the list `self.saved` holds
    their outputs in call order.
    """
    def __init__(self, n_emb, ni, nf, add_down=True, num_layers=1):
        super().__init__()
        self.resnets = nn.ModuleList([saved(EmbResBlock(n_emb, ni if i==0 else nf, nf), self)
                                      for i in range(num_layers)])
        self.down = saved(nn.Conv2d(nf, nf, 3, stride=2, padding=1), self) if add_down else nn.Identity()

    def forward(self, x, t):
        # Reset on every call so outputs from a previous batch are dropped.
        self.saved = []
        for resnet in self.resnets: x = resnet(x, t)
        x = self.down(x)
        return x
class UpBlock(nn.Module):
    """Embedding-aware up-sampling stage (replaces the earlier `UpBlock`).

    Each `EmbResBlock` receives the current activation concatenated on dim 1
    with one tensor popped from `ups`, plus the embedding `t`. An upsample
    follows when `add_up` is true.
    """
    def __init__(self, n_emb, ni, prev_nf, nf, add_up=True, num_layers=2):
        super().__init__()
        self.resnets = nn.ModuleList(
            [EmbResBlock(n_emb, (prev_nf if i==0 else nf)+(ni if (i==num_layers-1) else nf), nf)
            for i in range(num_layers)])
        self.up = upsample(nf) if add_up else nn.Identity()

    def forward(self, x, t, ups):
        # `ups` is consumed from the end and mutated in place.
        for resnet in self.resnets: x = resnet(torch.cat([x, ups.pop()], dim=1), t)
        return self.up(x)
class EmbUNetModel(nn.Module):
    """U-Net conditioned on a per-item scalar via sinusoidal embedding + MLP.

    `forward` takes a pair `(x, t)`. `t` is embedded with
    `timestep_embedding(t, nfs[0])`, passed through `emb_mlp`, and fed to
    every down, middle and up block.
    """
    def __init__( self, in_channels=3, out_channels=3, nfs=(224,448,672,896), num_layers=1):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, nfs[0], kernel_size=3, padding=1)
        self.n_temb = nf = nfs[0]
        n_emb = nf*4
        # TODO: remove act func from 1st MLP layer
        self.emb_mlp = nn.Sequential(lin(self.n_temb, n_emb, norm=nn.BatchNorm1d),
                                     lin(n_emb, n_emb))
        self.downs = nn.ModuleList()
        for i in range(len(nfs)):
            ni = nf
            nf = nfs[i]
            # The last level has no down-sampling conv.
            self.downs.append(DownBlock(n_emb, ni, nf, add_down=i!=len(nfs)-1, num_layers=num_layers))
        self.mid_block = EmbResBlock(n_emb, nfs[-1])

        rev_nfs = list(reversed(nfs))
        nf = rev_nfs[0]
        self.ups = nn.ModuleList()
        for i in range(len(nfs)):
            prev_nf = nf
            nf = rev_nfs[i]
            ni = rev_nfs[min(i+1, len(nfs)-1)]
            self.ups.append(UpBlock(n_emb, ni, prev_nf, nf, add_up=i!=len(nfs)-1, num_layers=num_layers+1))
        self.conv_out = unet_conv(nfs[0], out_channels, act=nn.SiLU, norm=nn.BatchNorm2d, bias=False)

    def forward(self, inp):
        x,t = inp
        temb = timestep_embedding(t, self.n_temb)
        emb = self.emb_mlp(temb)
        x = self.conv_in(x)
        saved = [x]
        for block in self.downs: x = block(x, emb)
        # Gather the per-layer outputs each DownBlock recorded during its forward.
        saved += [p for o in self.downs for p in o.saved]
        x = self.mid_block(x, emb)
        for block in self.ups: x = block(x, emb, saved)
        return self.conv_out(x)
# Train the embedding-conditioned U-Net with a one-cycle schedule.
model = EmbUNetModel(in_channels=1, out_channels=1, nfs=(32,64,128,256), num_layers=2)
lr = 1e-2
epochs = 25
opt_func = partial(optim.Adam, eps=1e-5)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
# NOTE(review): the model is constructed a second time here (two notebook
# cells); the instance created above is discarded.
model = EmbUNetModel(in_channels=1, out_channels=1, nfs=(32,64,128,256), num_layers=2)
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
loss | epoch | train |
---|---|---|
0.409 | 0 | train |
0.304 | 0 | eval |
0.221 | 1 | train |
0.338 | 1 | eval |
0.193 | 2 | train |
0.215 | 2 | eval |
0.182 | 3 | train |
0.219 | 3 | eval |
0.175 | 4 | train |
0.201 | 4 | eval |
0.169 | 5 | train |
0.206 | 5 | eval |
0.165 | 6 | train |
0.240 | 6 | eval |
0.162 | 7 | train |
0.180 | 7 | eval |
0.157 | 8 | train |
0.186 | 8 | eval |
0.155 | 9 | train |
0.222 | 9 | eval |
0.153 | 10 | train |
0.190 | 10 | eval |
0.151 | 11 | train |
0.164 | 11 | eval |
0.149 | 12 | train |
0.186 | 12 | eval |
0.148 | 13 | train |
0.158 | 13 | eval |
0.146 | 14 | train |
0.146 | 14 | eval |
0.145 | 15 | train |
0.152 | 15 | eval |
0.143 | 16 | train |
0.148 | 16 | eval |
0.143 | 17 | train |
0.142 | 17 | eval |
0.142 | 18 | train |
0.142 | 18 | eval |
0.140 | 19 | train |
0.140 | 19 | eval |
0.139 | 20 | train |
0.138 | 20 | eval |
0.139 | 21 | train |
0.139 | 21 | eval |
0.137 | 22 | train |
0.139 | 22 | eval |
0.137 | 23 | train |
0.138 | 23 | eval |
0.138 | 24 | train |
0.137 | 24 | eval |
# Sampling
from miniai.fid import ImageEval
# Load the classifier used for FID/KID and drop its last two layers.
# NOTE(security): torch.load unpickles arbitrary objects; only load files
# from a trusted source.
cmodel = torch.load('models/data_aug2.pkl')
del(cmodel[8])
del(cmodel[7])

bs = 2048
tds2 = dsd.with_transform(transformi)
# NOTE(review): `tds` (not `tds2`) is passed here; both use the same
# transform, so confirm which one was intended.
dls2 = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)

dt = dls2.train
xb,yb = next(iter(dt))

ie = ImageEval(cmodel, dls2, cbs=[DeviceCB()])

# Shape of the sample batch read by `sample` / `sample_lms`; the second
# assignment is the one in effect.
sz = (512,1,32,32)
sz = (2048,1,32,32)
def sigmas_karras(n, sigma_min=0.01, sigma_max=80., rho=7., device='cuda'):
    """Return `n` decreasing noise levels followed by a trailing 0.

    The levels are linearly spaced in `sigma**(1/rho)` space from `sigma_max`
    down to `sigma_min`. The result has length `n+1` and is placed on
    `device` (default `'cuda'`, matching the previous hard-coded `.cuda()`).
    """
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min**(1/rho)
    max_inv_rho = sigma_max**(1/rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho-max_inv_rho))**rho
    return torch.cat([sigmas, tensor([0.])]).to(device)
def denoise(model, x, sig):
    """Return the denoised estimate for `x` at noise level `sig`.

    Scales the input by `c_in`, calls `model((x*c_in, sig))`, and combines the
    output with `x` as `model_out*c_out + x*c_skip`.
    """
    # Add a leading dim to the (presumably scalar) sigma before it reaches the model.
    sig = sig[None] #* torch.ones((len(x),1), device=x.device)
    c_skip,c_out,c_in = scalings(sig)
    return model((x*c_in, sig))*c_out + x*c_skip
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Return `(sigma_down, sigma_up)` for a step from `sigma_from` to `sigma_to`.

    `sigma_up` is the amount of fresh noise to add and is capped at
    `sigma_to`; `sigma_down` satisfies `sigma_down**2 + sigma_up**2 ==
    sigma_to**2`. With a falsy `eta` no noise is added and `(sigma_to, 0.)`
    is returned.
    """
    if not eta: return sigma_to, 0.
    var_to,var_from = sigma_to**2,sigma_from**2
    sigma_up = min(sigma_to, eta * (var_to * (var_from-var_to)/var_from)**0.5)
    return (var_to-sigma_up**2)**0.5, sigma_up
@torch.no_grad()
def sample_euler_ancestral(x, sigs, i, model, eta=1.):
    """Take one ancestral Euler step from `sigs[i]` towards `sigs[i+1]`.

    Moves `x` deterministically to `sigma_down`, then adds Gaussian noise
    scaled by `sigma_up` (both from `get_ancestral_step`).
    """
    sig,sig2 = sigs[i],sigs[i+1]
    denoised = denoise(model, x, sig)
    sigma_down,sigma_up = get_ancestral_step(sig, sig2, eta=eta)
    x = x + (x-denoised)/sig*(sigma_down-sig)
    return x + torch.randn_like(x)*sigma_up
@torch.no_grad()
def sample_euler(x, sigs, i, model):
    """Take one deterministic Euler step from `sigs[i]` to `sigs[i+1]`."""
    sig,sig2 = sigs[i],sigs[i+1]
    denoised = denoise(model, x, sig)
    # (x-denoised)/sig is the slope; step it by the change in sigma.
    return x + (x-denoised)/sig*(sig2-sig)
@torch.no_grad()
def sample_heun(x, sigs, i, model, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Take one Heun (second-order) step from `sigs[i]` to `sigs[i+1]`.

    When `s_churn > 0` and `sigs[i]` lies in `[s_tmin, s_tmax]`, noise scaled
    by `s_noise` is first added to raise the level to `sigma_hat`. If the
    target sigma is 0 the plain Euler result is returned without correction.
    """
    sig,sig2 = sigs[i],sigs[i+1]
    n = len(sigs)
    gamma = min(s_churn/(n-1), 2**0.5-1) if s_tmin<=sig<=s_tmax else 0.
    eps = torch.randn_like(x) * s_noise
    sigma_hat = sig * (gamma+1)
    if gamma > 0: x = x + eps * (sigma_hat**2-sig**2)**0.5
    # NOTE(review): after churn `x` is at noise level `sigma_hat`, yet the
    # model is evaluated and the slope divided with `sig`. The two are equal
    # when gamma == 0; confirm whether `sigma_hat` was intended here.
    denoised = denoise(model, x, sig)
    d = (x-denoised)/sig
    dt = sig2-sigma_hat
    x_2 = x + d*dt
    if sig2==0: return x_2
    # Second slope at the Euler estimate; average the two for the final step.
    denoised_2 = denoise(model, x_2, sig2)
    d_2 = (x_2-denoised_2)/sig2
    d_prime = (d+d_2)/2
    return x + d_prime*dt
def sample(sampler, model, steps=100, sigma_max=80., **kwargs):
    """Run `sampler` over a `sigmas_karras` schedule and return every intermediate `x`.

    Starts from Gaussian noise of shape `sz` (module-level) scaled by
    `sigma_max` on the CUDA device. `kwargs` are forwarded to `sampler`.
    The last element of the returned list is the final sample.
    """
    preds = []
    x = torch.randn(sz).cuda()*sigma_max
    sigs = sigmas_karras(steps, sigma_max=sigma_max)
    for i in progress_bar(range(len(sigs)-1)):
        x = sampler(x, sigs, i, model, **kwargs)
        preds.append(x)
    return preds
# NOTE: `sample_lms` is defined further down this file (notebook cell
# order); that definition must be executed before this section.
preds = sample_lms(model, steps=20, order=3)
# preds = sample(sample_euler_ancestral, model, steps=100, eta=1.)
# preds = sample(sample_euler, model, steps=100)
# preds = sample(sample_heun, model, steps=20, s_churn=0.5)
# 100.00% [20/20 00:04<00:00]

s = preds[-1]
s.min(),s.max()
# (tensor(-1.09312, device='cuda:0'), tensor(1.43464, device='cuda:0'))

show_images(s[:25].clamp(-1,1), imsize=1.5)

# lms 20
ie.fid(s),ie.kid(s),s.shape
# (6.195896366748002, 0.011938275769352913, torch.Size([2048, 1, 32, 32]))

preds = sample_lms(model, steps=20, order=3)
s = preds[-1]
ie.fid(s),ie.kid(s),s.shape
# 100.00% [20/20 00:04<00:00]
# (4.967668251150826, 0.01714729703962803, torch.Size([2048, 1, 32, 32]))

preds = sample_lms(model, steps=20, order=3)
s = preds[-1]
ie.fid(s),ie.kid(s),s.shape
# 100.00% [20/20 00:04<00:00]
# (4.607266664456915, 0.0245591439306736, torch.Size([2048, 1, 32, 32]))
from scipy import integrate
def linear_multistep_coeff(order, t, i, j):
    """Return the weight of the `j`-th most recent slope for a multistep update.

    Integrates, from `t[i]` to `t[i+1]`, the Lagrange basis polynomial that is
    1 at `t[i-j]` and 0 at the other `order-1` nodes `t[i-k]`.

    Raises:
        ValueError: if fewer than `order` nodes are available at step `i`.
    """
    if order-1 > i: raise ValueError(f'Order {order} too high for step {i}')
    def fn(tau):
        prod = 1.
        for k in range(order):
            if j == k: continue
            prod *= (tau-t[i-k]) / (t[i-j]-t[i-k])
        return prod
    return integrate.quad(fn, t[i], t[i+1], epsrel=1e-4)[0]
@torch.no_grad()
def sample_lms(model, steps=100, order=4, sigma_max=80.):
    """Sample with a linear multistep method and return every intermediate `x`.

    Starts from Gaussian noise of shape `sz` (module-level) scaled by
    `sigma_max` on the CUDA device. Each step combines up to `order` of the
    most recent slopes using weights from `linear_multistep_coeff`.
    """
    preds = []
    x = torch.randn(sz).cuda()*sigma_max
    sigs = sigmas_karras(steps, sigma_max=sigma_max)
    ds = []
    for i in progress_bar(range(len(sigs)-1)):
        sig = sigs[i]
        denoised = denoise(model, x, sig)
        d = (x-denoised)/sig
        ds.append(d)
        # Keep only the `order` most recent slopes.
        if len(ds) > order: ds.pop(0)
        # Early steps have fewer past slopes, so use a lower order there.
        cur_order = min(i+1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigs, i, j) for j in range(cur_order)]
        # coeffs[0] pairs with the newest slope, hence reversed(ds).
        x = x + sum(coeff*d for coeff, d in zip(coeffs, reversed(ds)))
        preds.append(x)
    return preds