from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder
# Render images with the reversed-gray colormap (dark strokes on a light background).
mpl.rcParams['image.cmap'] = 'gray_r'
# Suppress logging records at WARNING level and below.
logging.disable(logging.WARNING)
Denoising Diffusion Probabilistic Models with miniai
- skip_showdoc: true
Imports
Load the dataset
# Column names of the image and label fields in the dataset.
xl, yl = 'image', 'label'
name = "fashion_mnist"
dsd = load_dataset(name)
dsd
@inplace
def transformi(b):
    """Tensorise and zero-pad every image in the batch column ``b[xl]``.

    Each image goes through ``TF.to_tensor`` and is then padded with zeros by
    2 elements on each side of its last two dimensions. The column is replaced
    in ``b``; the ``@inplace`` decorator presumably returns ``b`` itself —
    confirm against its definition.
    """
    padding = (2, 2, 2, 2)  # (left, right, top, bottom) for the last two dims
    b[xl] = [F.pad(TF.to_tensor(img), padding) for img in b[xl]]
bs = 128
# Attach transformi to the dataset dict; HF `with_transform` applies it when rows are accessed.
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=8)
dt = dls.train
xb, yb = next(iter(dt))
show_images(xb[:16], imsize=1.5)
# Linear beta (per-step noise variance) schedule over n_steps timesteps.
betamin, betamax, n_steps = 0.0001, 0.02, 1000
beta = torch.linspace(betamin, betamax, n_steps)
alpha = 1. - beta
alphabar = alpha.cumprod(dim=0)  # running product of alpha up to each timestep
sigma = beta.sqrt()
# Trailing semicolons suppress the notebook's repr of the returned line objects.
plt.plot(beta);
plt.plot(sigma);
plt.plot(alphabar);
Exported source
def noisify(x0, ᾱ):
    """Corrupt a clean batch with Gaussian noise at random DDPM timesteps.

    Args:
        x0: clean batch tensor, batch dimension first (any number of trailing dims).
        ᾱ: 1-D cumulative-product noise schedule (alpha-bar). Its length is the
            number of timesteps, so each ``t`` is drawn uniformly from ``[0, len(ᾱ))``.

    Returns:
        ``((xt, t), ε)`` where ``xt = sqrt(ᾱ[t]) * x0 + sqrt(1 - ᾱ[t]) * ε``,
        ``t`` is the per-sample ``torch.long`` timestep tensor and ``ε`` is the
        standard-normal noise; all three are on ``x0``'s device.
    """
    device = x0.device
    n = len(x0)
    # Sample over the full length of the schedule actually passed in, not a
    # module-level step count that may not match it (out-of-range index or
    # silently unused timesteps otherwise).
    t = torch.randint(0, len(ᾱ), (n,), dtype=torch.long)
    eps = torch.randn(x0.shape, device=device)
    # One ᾱ value per sample, shaped to broadcast over every non-batch dim.
    abar_t = ᾱ[t].reshape(-1, *([1] * (x0.dim() - 1))).to(device)
    xt = abar_t.sqrt() * x0 + (1 - abar_t).sqrt() * eps
    return (xt, t.to(device)), eps
# Noise 25 training images at random timesteps and display the timesteps drawn.
(xt, t), ε = noisify(xb[:25], alphabar)
t
tensor([ 26, 335, 620, 924, 950, 113, 378, 14, 210, 954, 231, 572, 315, 295, 567, 706, 749, 876, 73, 111, 899, 213, 541, 769, 287])
# Label each noised image with its timestep.
titles = fc.map_ex(t, '{}')
show_images(xt, imsize=1.5, titles=titles)
Training
Exported source
from diffusers import UNet2DModel
Exported source
@torch.no_grad()
def sample(model, sz, alpha, alphabar, sigma, n_steps):
    """Run DDPM ancestral sampling and return every intermediate image.

    Args:
        model: noise predictor, called as ``model((x_t, t_batch))``; its output
            must broadcast against ``x_t`` (presumably the predicted noise).
        sz: shape of the batch to generate, e.g. ``(16, 1, 32, 32)``.
        alpha, alphabar, sigma: 1-D schedule tensors indexed by timestep
            (``1 - beta``, the cumulative product of ``alpha``, and ``sqrt(beta)``).
        n_steps: number of denoising steps; timesteps ``n_steps-1 .. 0`` are visited.

    Returns:
        List of ``n_steps`` CPU tensors of shape ``sz``; the last entry is the
        final (t == 0) sample.
    """
    device = next(model.parameters()).device
    x_t = torch.randn(sz, device=device)
    preds = []
    for t in reversed(range(n_steps)):
        t_batch = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
        # No fresh noise is injected on the final step (t == 0).
        z = (torch.randn(x_t.shape) if t > 0 else torch.zeros(x_t.shape)).to(device)
        abar_t1 = alphabar[t-1] if t > 0 else alphabar.new_tensor(1.)
        bbar_t = 1 - alphabar[t]
        bbar_t1 = 1 - abar_t1
        # Use the model passed in, not whatever global learner happens to exist.
        noise_pred = model((x_t, t_batch))
        # Estimate of x_0 from the predicted noise, clamped to [-1, 1].
        x_0_hat = ((x_t - bbar_t.sqrt() * noise_pred) / alphabar[t].sqrt()).clamp(-1, 1)
        # Weighted combination of x_0_hat and x_t, plus sigma_t-scaled noise.
        x_t = (x_0_hat * abar_t1.sqrt() * (1 - alpha[t]) / bbar_t
               + x_t * alpha[t].sqrt() * bbar_t1 / bbar_t
               + sigma[t] * z)
        preds.append(x_t.cpu())
    return preds
Exported source
class DDPMCB(Callback):
    """Callback that turns clean batches into DDPM training batches.

    On construction it builds a linear beta schedule of ``n_steps`` values from
    ``beta_min`` to ``beta_max`` and derives ``α = 1 - β``, ``ᾱ = cumprod(α)``
    and ``σ = sqrt(β)``. Before each batch it replaces ``learn.batch`` with
    ``noisify(learn.batch[0], ᾱ)``. ``sample`` forwards to the module-level
    ``sample`` function with this schedule.
    """
    # Presumably ensures this runs after DeviceCB — confirm callback ordering rule.
    order = DeviceCB.order + 1

    def __init__(self, n_steps, beta_min, beta_max):
        super().__init__()
        fc.store_attr()  # stores n_steps, beta_min and beta_max on self
        betas = torch.linspace(self.beta_min, self.beta_max, self.n_steps)
        self.beta = betas
        self.α = 1. - betas
        self.ᾱ = self.α.cumprod(dim=0)
        self.σ = betas.sqrt()

    def before_batch(self, learn):
        clean_images = learn.batch[0]
        learn.batch = noisify(clean_images, self.ᾱ)

    def sample(self, model, sz):
        # Inside a method, `sample` resolves to the module-level function, not this method.
        return sample(model, sz, self.α, self.ᾱ, self.σ, self.n_steps)
Exported source
class UNet(UNet2DModel):
    """``UNet2DModel`` that accepts one packed input and returns a bare tensor.

    ``x`` is unpacked positionally into the parent ``forward`` (presumably
    ``(noisy_images, timesteps)`` — confirm), and only the ``.sample`` attribute
    of the parent's output is returned.
    """
    def forward(self, x):
        output = super().forward(*x)
        return output.sample
ddpm_cb = DDPMCB(n_steps=1000, beta_min=0.0001, beta_max=0.02)
model = UNet(in_channels=1, out_channels=1, block_out_channels=(16, 32, 64, 64), norm_num_groups=8)
learn = TrainLearner(model, dls, nn.MSELoss())
# Non-training fit with SingleBatchCB (presumably stops after one batch) just to
# inspect a batch after DDPMCB has noised it.
learn.fit(train=False, cbs=[ddpm_cb, SingleBatchCB()])
(xt, t), ε = learn.batch
show_images(xt[:25], titles=fc.map_ex(t[:25], '{}'), imsize=1.5)

# Training configuration.
lr = 5e-3
epochs = 3
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [ddpm_cb, DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched)]
model = UNet(in_channels=1, out_channels=1, block_out_channels=(16, 32, 64, 128), norm_num_groups=8)
Exported source
def init_ddpm(model):
    """Initialise a UNet2DModel-style network in place for DDPM training.

    Zeroes the weight of every resnet's ``conv2`` in both the down and up blocks
    and of ``conv_out``, and applies ``init.orthogonal_`` to each downsampler's
    conv weight. Biases are left untouched. Returns None.
    """
    def zero_weight(conv):
        conv.weight.data.zero_()

    for down in model.down_blocks:
        for resnet in down.resnets:
            zero_weight(resnet.conv2)
        # Wrapped in fc.L, presumably so a block whose `downsamplers` is None
        # iterates as empty — confirm.
        for downsampler in fc.L(down.downsamplers):
            init.orthogonal_(downsampler.conv.weight)

    for up in model.up_blocks:
        for resnet in up.resnets:
            zero_weight(resnet.conv2)

    zero_weight(model.conv_out)
init_ddpm(model)
opt_func = partial(optim.Adam, eps=1e-5)
learn = TrainLearner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
loss | epoch | train |
---|---|---|
0.131 | 0 | train |
0.024 | 0 | eval |
0.022 | 1 | train |
0.022 | 1 | eval |
0.019 | 2 | train |
0.020 | 2 | eval |
mdl_path = Path('models')
mdl_path.mkdir(exist_ok=True)  # torch.save does not create missing directories
torch.save(learn.model, mdl_path/'fashion_ddpm2.pkl')
# NOTE: this pickles the whole module. torch.load unpickles arbitrary objects, so only
# load files you trust; recent PyTorch versions default to weights_only=True, which
# rejects full-module pickles (pass weights_only=False there).
learn.model = torch.load(mdl_path/'fashion_ddpm2.pkl')
Sampling
samples = ddpm_cb.sample(learn.model, (16, 1, 32, 32))
show_images(samples[-1], figsize=(5, 5))  # last entry is the final (t == 0) step
Mixed Precision
bs = 512
# Peek at a raw, default-collated batch of two rows.
next(iter(DataLoader(tds['train'], batch_size=2)))
{'image': tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]]),
'label': tensor([9, 0])}
Exported source
def collate_ddpm(b):
    """Collate dataset rows, then noisify the image column with the global ``alphabar``.

    Returns whatever ``noisify`` returns for the stacked images ``batch[xl]``.
    """
    batch = default_collate(b)
    return noisify(batch[xl], alphabar)
def dl_ddpm(ds):
    """Build a DataLoader over ``ds`` that noisifies each batch via ``collate_ddpm``.

    The batch size is the module-level ``bs``, read at call time.
    """
    return DataLoader(
        ds,
        batch_size=bs,
        collate_fn=collate_ddpm,
        num_workers=4,
    )
# Noising now happens inside the DataLoaders' collate_fn rather than in a callback.
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))
Exported source
class MixedPrecision(TrainCB):
    """Training callback using CUDA fp16 autocast plus a ``GradScaler``.

    A fresh scaler is created at ``before_fit``. Autocast is entered manually in
    ``before_batch`` and exited in ``after_loss``. The backward pass runs on the
    scaled loss, and the optimizer step goes through ``scaler.step`` followed by
    ``scaler.update``.
    """
    # Presumably places this after DeviceCB in callback order — confirm.
    order = DeviceCB.order + 10

    def before_fit(self, learn):
        self.scaler = torch.cuda.amp.GradScaler()

    def before_batch(self, learn):
        # Entered by hand, presumably so the forward pass and loss computation
        # (run between this hook and after_loss) happen under autocast.
        ctx = torch.autocast("cuda", dtype=torch.float16)
        self.autocast = ctx
        ctx.__enter__()

    def after_loss(self, learn):
        self.autocast.__exit__(None, None, None)

    def backward(self, learn):
        scaled_loss = self.scaler.scale(learn.loss)
        scaled_loss.backward()

    def step(self, learn):
        scaler = self.scaler
        scaler.step(learn.opt)
        scaler.update()
lr = 1e-2
epochs = 3
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
# No DDPMCB here: batches arrive already noised from collate_ddpm.
cbs = [DeviceCB(), MixedPrecision(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched)]
model = UNet(in_channels=1, out_channels=1, block_out_channels=(16, 32, 64, 128), norm_num_groups=8)
init_ddpm(model)
# NOTE(review): the recorded run of this cell shows nan loss in every epoch —
# investigate (e.g. lr / fp16 overflow) before trusting the samples drawn below.
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
loss | epoch | train |
---|---|---|
nan | 0 | train |
nan | 0 | eval |
nan | 1 | train |
nan | 1 | eval |
nan | 2 | train |
nan | 2 | eval |
samples = sample(learn.model, (32, 1, 32, 32), alpha, alphabar, sigma, n_steps)
show_images(samples[-1][:25], imsize=1.5)  # first 25 images of the final step
torch.save(learn.model, 'models/fashion_ddpm_mp.pkl')
Accelerate
Run `pip install accelerate` before running this section.
Exported source
from accelerate import Accelerator
Exported source
class AccelerateCB(TrainCB):
    """Training callback that delegates preparation and backward to HF Accelerate.

    Args:
        n_inp: forwarded to ``TrainCB`` (presumably how many leading batch
            elements are fed to the model — confirm).
        mixed_precision: passed to ``Accelerator``; defaults to ``"fp16"``.

    At ``before_fit`` the model, optimizer and train/valid loaders are passed
    through ``Accelerator.prepare`` and the returned objects replace the
    originals on ``learn``. ``backward`` uses ``Accelerator.backward``.
    """
    # Presumably places this after DeviceCB in callback order — confirm.
    order = DeviceCB.order + 10

    def __init__(self, n_inp=1, mixed_precision="fp16"):
        super().__init__(n_inp=n_inp)
        self.acc = Accelerator(mixed_precision=mixed_precision)

    def before_fit(self, learn):
        prepared = self.acc.prepare(
            learn.model, learn.opt, learn.dls.train, learn.dls.valid)
        learn.model, learn.opt, learn.dls.train, learn.dls.valid = prepared

    def backward(self, learn):
        self.acc.backward(learn.loss)
Exported source
def noisify(x0, ᾱ):
    """Corrupt a clean batch with Gaussian noise at random DDPM timesteps.

    Same as the earlier ``noisify`` but returns a flat 3-tuple.

    Args:
        x0: clean batch tensor, batch dimension first (any number of trailing dims).
        ᾱ: 1-D cumulative-product noise schedule (alpha-bar). Its length is the
            number of timesteps, so each ``t`` is drawn uniformly from ``[0, len(ᾱ))``.

    Returns:
        ``(xt, t, ε)`` where ``xt = sqrt(ᾱ[t]) * x0 + sqrt(1 - ᾱ[t]) * ε``,
        ``t`` is the per-sample ``torch.long`` timestep tensor and ``ε`` is the
        standard-normal noise; all three are on ``x0``'s device.
    """
    device = x0.device
    n = len(x0)
    # Sample over the full length of the schedule actually passed in, not a
    # module-level step count that may not match it.
    t = torch.randint(0, len(ᾱ), (n,), dtype=torch.long)
    eps = torch.randn(x0.shape, device=device)
    # One ᾱ value per sample, shaped to broadcast over every non-batch dim.
    abar_t = ᾱ[t].reshape(-1, *([1] * (x0.dim() - 1))).to(device)
    xt = abar_t.sqrt() * x0 + (1 - abar_t).sqrt() * eps
    return xt, t.to(device), eps
# Rebuild the loaders; collate_ddpm now yields flat (xt, t, ε) batches from the redefined noisify.
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))
Exported source
class DDPMCB2(Callback):
    """Replace ``learn.preds`` with its ``.sample`` attribute after each prediction.

    Needed when the model's output object wraps the tensor (e.g. a plain
    ``UNet2DModel`` — confirm for other models).
    """
    def after_predict(self, learn):
        output = learn.preds
        learn.preds = output.sample
# Plain UNet2DModel this time: AccelerateCB(n_inp=2) presumably feeds (xt, t) as two
# positional inputs, and DDPMCB2 unwraps `.sample` from the model output.
model = UNet2DModel(in_channels=1, out_channels=1, block_out_channels=(16, 32, 64, 128), norm_num_groups=8)
init_ddpm(model)
cbs = [DDPMCB2(), DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), AccelerateCB(n_inp=2)]
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
loss | epoch | train |
---|---|---|
0.201 | 0 | train |
0.031 | 0 | eval |
0.025 | 1 | train |
0.022 | 1 | eval |
0.022 | 2 | train |
0.021 | 2 | eval |
A sneaky trick
Exported source
class MultDL:
    """Iterable wrapper that yields every item of ``dl`` ``mult`` times in a row.

    ``len()`` reports ``len(dl) * mult``. Repeats are the very same object, not
    copies.
    """
    def __init__(self, dl, mult=2):
        self.dl = dl
        self.mult = mult

    def __len__(self):
        return len(self.dl) * self.mult

    def __iter__(self):
        for batch in self.dl:
            yield from (batch,) * self.mult
dls.train = MultDL(dls.train)  # yield each training batch twice (MultDL's default mult=2)