Karras pre-conditioning

import timm, torch, random, datasets, math, fastcore.all as fc, numpy as np, matplotlib as mpl, matplotlib.pyplot as plt
import k_diffusion as K, torchvision.transforms as T
import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch.utils.data import DataLoader,default_collate
from pathlib import Path
from torch.nn import init
from fastcore.foundation import L
from torch import nn,tensor
from datasets import load_dataset
from operator import itemgetter
from torcheval.metrics import MulticlassAccuracy
from functools import partial
from torch.optim import lr_scheduler
from torch import optim
from fastAIcourse.datasets import *
from fastAIcourse.conv import *
from fastAIcourse.learner import *
from fastAIcourse.activations import *
from fastAIcourse.init import *
from fastAIcourse.sgd import *
from fastAIcourse.resnet import *
from fastAIcourse.augment import *
from fastAIcourse.accel import *
from fastprogress import progress_bar
from diffusers import UNet2DModel, DDIMPipeline, DDPMPipeline, DDIMScheduler, DDPMScheduler
torch.set_printoptions(precision=5, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'
mpl.rcParams['figure.dpi'] = 70

import logging
logging.disable(logging.WARNING)

set_seed(42)
if fc.defaults.cpus>8: fc.defaults.cpus=8
xl,yl = 'image','label'
name = "fashion_mnist"
n_steps = 1000
bs = 512
dsd = load_dataset(name)
@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))*2-1 for o in b[xl]]
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs)
dl = dls.train
xb,yb = b = next(iter(dl))

# sig_data = xb.std()
sig_data = 0.66
y is the clean signal, n is N(0,1) noise.
def scalings(sig):
    totvar = sig**2+sig_data**2
    # c_skip,c_out,c_in
    return sig_data**2/totvar, sig*sig_data/totvar.sqrt(), 1/totvar.sqrt()
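To build intuition, here's a quick look (added for illustration, not in the original notebook) at how the scalings behave at small, typical, and large noise levels: at tiny sig the skip path dominates (c_skip≈1), while at huge sig the model output dominates and c_in rescales the input back towards unit variance.

# Illustrative check: c_skip, c_out, c_in at a few noise levels
for s in (0.01, 0.66, 80.):
    c_skip,c_out,c_in = scalings(tensor(s))
    print(f'sig={s:6.2f}  c_skip={c_skip:.4f}  c_out={c_out:.4f}  c_in={c_in:.4f}')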
sig_samp = (torch.randn([10000])*1.2-1.2).exp()
plt.hist(sig_samp);

import seaborn as sns
sns.kdeplot(sig_samp, clip=(0,10));
def noisify(x0):
    device = x0.device
    sig = (torch.randn([len(x0)])*1.2-1.2).exp().to(x0).reshape(-1,1,1,1)
    noise = torch.randn_like(x0, device=device)
    c_skip,c_out,c_in = scalings(sig)
    noised_input = x0 + noise*sig
    target = (x0-c_skip*noised_input)/c_out
    return (noised_input*c_in,sig.squeeze()),target
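A property worth verifying (a check added here for illustration): the target is defined so that c_skip*noised_input + c_out*target reconstructs x0 exactly.

# Illustrative check: invert the target definition to recover the clean batch
(x_in,s),t = noisify(xb)
c_skip,c_out,c_in = scalings(s.reshape(-1,1,1,1))
recon = c_skip*(x_in/c_in) + c_out*t
print((recon-xb).abs().max())  # expect ~0 up to float error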
def collate_ddpm(b): return noisify(default_collate(b)[xl])
def dl_ddpm(ds): return DataLoader(ds, batch_size=bs, collate_fn=collate_ddpm, num_workers=8)
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))
dl = dls.train
(noised_input,sig),target = b = next(iter(dl))

show_images(noised_input[:25], imsize=1.5, titles=fc.map_ex(sig[:25], '{:.02f}'))
show_images(target[:25], imsize=1.5, titles=fc.map_ex(sig[:25], '{:.02f}'))
noised_input.mean(),noised_input.std(),target.mean(),target.std()
(tensor(-0.69019), tensor(1.01665), tensor(-0.40007), tensor(1.03293))
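Both the scaled input and the target come out with roughly unit variance, which is exactly what the pre-conditioning is designed to achieve.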
Train
class UNet(UNet2DModel):
def forward(self, x): return super().forward(*x).sample
def init_ddpm(model):
    for o in model.down_blocks:
        for p in o.resnets: p.conv2.weight.data.zero_()
        for p in fc.L(o.downsamplers): init.orthogonal_(p.conv.weight)
    for o in model.up_blocks:
        for p in o.resnets: p.conv2.weight.data.zero_()
    model.conv_out.weight.data.zero_()
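A quick way to see what the init does (an illustrative check, not in the original): with each block's final conv and conv_out zeroed, a freshly initialised model contributes almost nothing, so early training leans on the c_skip path.

# Illustrative check: conv_out really is zeroed after init_ddpm
m = UNet(in_channels=1, out_channels=1, block_out_channels=(32, 64, 128, 256), norm_num_groups=8)
init_ddpm(m)
print(m.conv_out.weight.abs().sum())  # expect tensor(0.)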
lr = 1e-2
epochs = 25
opt_func = partial(optim.Adam, eps=1e-5)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), MixedPrecision(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched)]

model = UNet(in_channels=1, out_channels=1, block_out_channels=(32, 64, 128, 256), norm_num_groups=8)
init_ddpm(model)
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
loss | epoch | train |
---|---|---|
0.703 | 0 | train |
0.348 | 0 | eval |
0.263 | 1 | train |
0.224 | 1 | eval |
0.201 | 2 | train |
0.197 | 2 | eval |
0.188 | 3 | train |
0.186 | 3 | eval |
0.179 | 4 | train |
0.179 | 4 | eval |
0.172 | 5 | train |
0.174 | 5 | eval |
0.169 | 6 | train |
0.165 | 6 | eval |
0.160 | 7 | train |
0.167 | 7 | eval |
0.158 | 8 | train |
0.161 | 8 | eval |
0.154 | 9 | train |
0.164 | 9 | eval |
0.152 | 10 | train |
0.151 | 10 | eval |
0.151 | 11 | train |
0.153 | 11 | eval |
0.147 | 12 | train |
0.150 | 12 | eval |
0.147 | 13 | train |
0.150 | 13 | eval |
0.145 | 14 | train |
0.146 | 14 | eval |
0.143 | 15 | train |
0.144 | 15 | eval |
0.142 | 16 | train |
0.144 | 16 | eval |
0.141 | 17 | train |
0.141 | 17 | eval |
0.140 | 18 | train |
0.141 | 18 | eval |
0.139 | 19 | train |
0.140 | 19 | eval |
0.139 | 20 | train |
0.140 | 20 | eval |
0.138 | 21 | train |
0.138 | 21 | eval |
0.137 | 22 | train |
0.137 | 22 | eval |
0.137 | 23 | train |
0.139 | 23 | eval |
0.137 | 24 | train |
0.137 | 24 | eval |
# torch.save(learn.model, 'models/fashion_karras.pkl')
# model = learn.model = torch.load('models/fashion_karras.pkl').cuda()
def denoise(target, noised_input): return target*c_out + noised_input*c_skip
with torch.no_grad():
    sigr = sig.cuda().reshape(-1,1,1,1)
    c_skip,c_out,c_in = scalings(sigr)
    targ_pred = learn.model((noised_input.cuda(),sig.cuda()))
    x0_pred = denoise(targ_pred, noised_input.cuda()/c_in)
show_images(noised_input[:25], imsize=1.5, titles=fc.map_ex(sig[:25], '{:.02f}'))
show_images(x0_pred[:25].clamp(-1,1), imsize=1.5, titles=fc.map_ex(sig[:25], '{:.02f}'))
show_images(denoise(target.cuda(), noised_input.cuda()/c_in)[:25], imsize=1.5, titles=fc.map_ex(sig[:25], '{:.02f}'))
sig_r = tensor(80.).cuda().reshape(-1,1,1,1)
c_skip,c_out,c_in = scalings(sig_r)
x_r = torch.randn(32,1,32,32).to(model.device)*sig_r

with torch.no_grad():
    targ_pred = learn.model((x_r*c_in,sig_r.squeeze()))
    x0_pred = denoise(targ_pred, x_r)

show_images(x0_pred[:25], imsize=1.5)
x0_pred.max(),x0_pred.min(),x0_pred.mean(),x0_pred.std()
(tensor(0.63882, device='cuda:0'),
tensor(-1.18548, device='cuda:0'),
tensor(-0.54164, device='cuda:0'),
tensor(0.43493, device='cuda:0'))
Sampling
from miniai.fid import ImageEval
cmodel = torch.load('models/data_aug2.pkl')
del(cmodel[8])
del(cmodel[7])
bs = 2048
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)

dt = dls.train
xb,yb = next(iter(dt))

ie = ImageEval(cmodel, dls, cbs=[DeviceCB()])
sz = (512,1,32,32)
sz = (2048,1,32,32)
def sigmas_karras(n, sigma_min=0.01, sigma_max=80., rho=7., device='cpu'):
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min**(1/rho)
    max_inv_rho = sigma_max**(1/rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho-max_inv_rho))**rho
    return torch.cat([sigmas, tensor([0.])]).to(device)

sk = sigmas_karras(100)
plt.plot(sk);
def denoise(model, x, sig):
    c_skip,c_out,c_in = scalings(sig)
    return model((x*c_in, sig))*c_out + x*c_skip
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    if not eta: return sigma_to, 0.
    var_to,var_from = sigma_to**2,sigma_from**2
    sigma_up = min(sigma_to, eta * (var_to * (var_from-var_to)/var_from)**0.5)
    return (var_to-sigma_up**2)**0.5, sigma_up
@torch.no_grad()
def sample_euler_ancestral(x, sigs, i, model, eta=1.):
    sig,sig2 = sigs[i],sigs[i+1]
    denoised = denoise(model, x, sig)
    sigma_down,sigma_up = get_ancestral_step(sig, sig2, eta=eta)
    x = x + (x-denoised)/sig*(sigma_down-sig)
    return x + torch.randn_like(x)*sigma_up
@torch.no_grad()
def sample_euler(x, sigs, i, model):
    sig,sig2 = sigs[i],sigs[i+1]
    denoised = denoise(model, x, sig)
    return x + (x-denoised)/sig*(sig2-sig)
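A toy check (illustrative, using a hypothetical stand-in model that always predicts zero, so at high sigma the denoised estimate is ≈0): one Euler step then just rescales x by sig2/sig, showing that the probability-flow ODE shrinks pure noise in proportion to sigma.

# Illustrative toy: with a zero-predicting stand-in model, the step scales x by sig2/sig
zero_model = lambda inp: torch.zeros_like(inp[0])  # hypothetical stand-in, not the trained UNet
x_t = torch.full((1,1,1,1), 80.)
print(sample_euler(x_t, tensor([80.,40.]), 0, zero_model))  # expect ~40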
@torch.no_grad()
def sample_heun(x, sigs, i, model, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    sig,sig2 = sigs[i],sigs[i+1]
    n = len(sigs)
    gamma = min(s_churn/(n-1), 2**0.5-1) if s_tmin<=sig<=s_tmax else 0.
    eps = torch.randn_like(x) * s_noise
    sigma_hat = sig * (gamma+1)
    if gamma > 0: x = x + eps * (sigma_hat**2-sig**2)**0.5
    denoised = denoise(model, x, sig)
    d = (x-denoised)/sig
    dt = sig2-sigma_hat
    x_2 = x + d*dt
    if sig2==0: return x_2
    denoised_2 = denoise(model, x_2, sig2)
    d_2 = (x_2-denoised_2)/sig2
    d_prime = (d+d_2)/2
    return x + d_prime*dt
def sample(sampler, model, steps=100, sigma_max=80., **kwargs):
    preds = []
    x = torch.randn(sz).to(model.device)*sigma_max
    sigs = sigmas_karras(steps, device=model.device, sigma_max=sigma_max)
    for i in progress_bar(range(len(sigs)-1)):
        x = sampler(x, sigs, i, model, **kwargs)
        preds.append(x)
    return preds
# preds = sample_lms(model, steps=20, order=3)
# preds = sample(sample_euler_ancestral, model, steps=100, eta=0.5)
preds = sample(sample_euler, model, steps=100)
# preds = sample(sample_heun, model, steps=20, s_churn=0.5)
100.00% [20/20 00:09<00:00]
s = preds[-1]
s.min(),s.max()
(tensor(-1.08955, device='cuda:0'), tensor(1.46819, device='cuda:0'))
show_images(s[:25].clamp(-1,1), imsize=1.5)
# euler 100
ie.fid(s),ie.kid(s),s.shape
(5.231043207481207, 0.0031656520441174507, torch.Size([2048, 1, 32, 32]))
# euler 100
ie.fid(s),ie.kid(s),s.shape
(5.406003616592329, 0.015411057509481907, torch.Size([2048, 1, 32, 32]))
# ancestral 100 0.5
ie.fid(s),ie.kid(s),s.shape
(5.452807558586642, 0.0071729626506567, torch.Size([2048, 1, 32, 32]))
# heun 50
ie.fid(s),ie.kid(s),s.shape
(6.221842673288506, 0.023713070899248123, torch.Size([2048, 1, 32, 32]))
# heun 20
ie.fid(s),ie.kid(s),s.shape
(5.610681075267394, -0.005569742992520332, torch.Size([2048, 1, 32, 32]))
# heun 20, churn 0.5
ie.fid(s),ie.kid(s),s.shape
(5.2517917311790825, 0.026914160698652267, torch.Size([2048, 1, 32, 32]))
# lms 20
ie.fid(s),ie.kid(s),s.shape
(5.061003587997561, 0.019381564110517502, torch.Size([2048, 1, 32, 32]))
# reals
ie.fid(xb)
2.5736580178430586
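Of the samplers tried here, LMS with just 20 steps gives the best FID (≈5.06), though every sampler remains well above the real-image baseline of ≈2.57.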
from scipy import integrate
def linear_multistep_coeff(order, t, i, j):
    if order-1 > i: raise ValueError(f'Order {order} too high for step {i}')
    def fn(tau):
        prod = 1.
        for k in range(order):
            if j == k: continue
            prod *= (tau-t[i-k]) / (t[i-j]-t[i-k])
        return prod
    return integrate.quad(fn, t[i], t[i+1], epsrel=1e-4)[0]
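Since the Lagrange basis polynomials sum to 1 at every tau, the coefficients for step i must sum to the step size t[i+1]-t[i]; a quick check (illustrative):

# Illustrative check: integrated Lagrange coefficients sum to the step size
ts = sigmas_karras(10)
cs = [linear_multistep_coeff(3, ts, 4, j) for j in range(3)]
print(sum(cs), (ts[5]-ts[4]).item())  # the two numbers should agree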
@torch.no_grad()
def sample_lms(model, steps=100, order=4, sigma_max=80.):
    preds = []
    x = torch.randn(sz).to(model.device)*sigma_max
    sigs = sigmas_karras(steps, device=model.device, sigma_max=sigma_max)
    ds = []
    for i in progress_bar(range(len(sigs)-1)):
        sig = sigs[i]
        denoised = denoise(model, x, sig)
        d = (x-denoised)/sig
        ds.append(d)
        if len(ds) > order: ds.pop(0)
        cur_order = min(i+1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigs, i, j) for j in range(cur_order)]
        x = x + sum(coeff*d for coeff, d in zip(coeffs, reversed(ds)))
        preds.append(x)
    return preds