torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'
import logging
logging.disable(logging.WARNING)
set_seed(42)
if fc.defaults.cpus>8: fc.defaults.cpus=8
Denoising Diffusion Probabilistic Models with miniai
Now that we've written our own barebones training library, let's make some progress towards exploring diffusion models and building Stable Diffusion from scratch.
We'll start by building and training the model described in the seminal 2020 paper Denoising Diffusion Probabilistic Models (DDPM). For context: although diffusion models were technically invented back in 2015, they flew under the radar until this 2020 paper because they were complicated and difficult to train. The DDPM paper made some crucial assumptions that significantly simplify the training and generation processes, as we will see here. Later diffusion models all build upon the framework introduced in this paper.
Let’s get started and train our own DDPM!
Imports
We’ll start with some imports.
mpl.rcParams['image.cmap'] = 'gray'
logging.disable(logging.WARNING)
Load the dataset
We will load the dataset from HuggingFace Hub:
x,y = 'image','label'
name = "fashion_mnist"
dsd = load_dataset(name)
To make life simpler (mostly with the model architecture), we’ll resize the 28x28 images to 32x32:
Exported source
@inplace
def transformi(b): b[x] = [TF.resize(TF.to_tensor(o), (32,32), antialias=True) for o in b[x]]
Let's set our batch size and create our DataLoaders with this batch size. We can confirm the shapes are correct. Note that while we do get the labels for the dataset, we actually don't care about them for our task of unconditional image generation.
set_seed(42)
bs = 128
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=8)
dt = dls.train
xb,yb = next(iter(dt))
xb.shape,yb[:10]
(torch.Size([128, 1, 32, 32]), tensor([5, 7, 4, 7, 3, 8, 9, 5, 3, 1]))
Create model
We will create a U-net. A U-net looks something like this:
The DDPM U-net is a modification of this with some modern tricks like using attention.
We will cover how U-nets are created and how modules like attention work in future lessons. For now, we’ll import the U-net from the diffusers library:
from diffusers import UNet2DModel
model = UNet2DModel(in_channels=1, out_channels=1, block_out_channels=(32, 64, 128, 128))
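As a quick sanity check (a small sketch, not part of the original notebook), we can see what this diffusers model expects and returns: a batch of images plus a batch of timesteps goes in, and the prediction comes back on the .sample attribute of the output object, which is why the training callback below calls .sample on the model output.

xt = torch.randn(4, 1, 32, 32)     # a fake batch of "noisy" 32x32 images
t  = torch.randint(0, 1000, (4,))  # one timestep per image
out = model(xt, t)                 # forward pass of the diffusers UNet2DModel
out.sample.shape                   # -> torch.Size([4, 1, 32, 32]), same shape as the input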
Training - easy with a callback!
DDPM is trained quite simply in a few steps:
1. Randomly select some timesteps in an iterative noising process.
2. Add noise corresponding to this timestep to the original image. For increasing timesteps, the variance of the noise increases (the noising step is summarized in the equation below).
3. Pass in this noisy image and the timestep to our model.
4. The model is trained with an MSE loss between the model output and the amount of noise added to the image.
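In equation form (this is the standard DDPM forward process, written with the same symbols the callback below uses), the noising in step 2 is:

$$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I), \qquad \bar\alpha_t = \prod_{i=1}^{t}(1-\beta_i)$$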
We will implement this in a callback. The callback will randomly select the timestep and create the noisy image before setting up our input and ground truth tensors for the model forward pass and loss calculation.
After training, we need to sample from this model. This is an iterative denoising process that starts from pure noise: we keep removing the noise predicted by the neural network, following a noise schedule that is the reverse of the one we used during training. This is also done in our callback.
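For reference, here is the update that the sample method below implements (the standard DDPM reverse step, written in the same form as the code; ε_θ is the model's noise prediction, σ_t = √β_t, and z ∼ N(0, I) except at the final step, where z = 0):

$$\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\varepsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}, \qquad x_{t-1} = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,\hat{x}_0 + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t + \sigma_t z$$

The code additionally clamps the estimate of x_0 to [-1, 1] before using it.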
Exported source
class DDPMCB(TrainCB):
    order = DeviceCB.order+1
    def __init__(self, n_steps, beta_min, beta_max):
        super().__init__()
        self.n_steps,self.βmin,self.βmax = n_steps,beta_min,beta_max
        # variance schedule, linearly increased with timestep
        self.β = torch.linspace(self.βmin, self.βmax, self.n_steps)
        self.α = 1. - self.β
        self.ᾱ = torch.cumprod(self.α, dim=0)
        self.σ = self.β.sqrt()

    # diffusers models return an output object; the prediction lives in `.sample`
    def predict(self, learn): learn.preds = learn.model(*learn.batch[0]).sample

    def before_batch(self, learn):
        device = learn.batch[0].device
        ε = torch.randn(learn.batch[0].shape, device=device)  # noise
        x0 = learn.batch[0]  # original images, x_0
        self.ᾱ = self.ᾱ.to(device)
        n = x0.shape[0]
        # select random timesteps
        t = torch.randint(0, self.n_steps, (n,), device=device, dtype=torch.long)
        ᾱ_t = self.ᾱ[t].reshape(-1, 1, 1, 1).to(device)
        xt = ᾱ_t.sqrt()*x0 + (1-ᾱ_t).sqrt()*ε  # noisify the image
        # input to our model is noisy image and timestep, ground truth is the noise
        learn.batch = ((xt, t), ε)

    @torch.no_grad()
    def sample(self, model, sz):
        device = next(model.parameters()).device
        x_t = torch.randn(sz, device=device)  # start from pure noise
        preds = []
        for t in reversed(range(self.n_steps)):
            t_batch = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
            z = (torch.randn(x_t.shape) if t > 0 else torch.zeros(x_t.shape)).to(device)
            ᾱ_t1 = self.ᾱ[t-1] if t > 0 else torch.tensor(1)
            b̄_t = 1 - self.ᾱ[t]
            b̄_t1 = 1 - ᾱ_t1
            noise_pred = model(x_t, t_batch).sample
            # estimate x_0, then step back to x_{t-1} using the posterior mean plus noise
            x_0_hat = ((x_t - b̄_t.sqrt() * noise_pred)/self.ᾱ[t].sqrt()).clamp(-1,1)
            x0_coeff = ᾱ_t1.sqrt()*(1-self.α[t])/b̄_t
            xt_coeff = self.α[t].sqrt()*b̄_t1/b̄_t
            x_t = x_0_hat*x0_coeff + x_t*xt_coeff + self.σ[t]*z
            preds.append(x_t.cpu())
        return preds
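To get some intuition for the linear schedule this callback builds, here is a tiny standalone check (a sketch, not part of the notebook): ᾱ starts very close to 1 (the image is barely perturbed) and ends very close to 0 (essentially pure noise).

import torch

n_steps, beta_min, beta_max = 1000, 0.0001, 0.02
β = torch.linspace(beta_min, beta_max, n_steps)  # same linear β schedule as DDPMCB
ᾱ = torch.cumprod(1. - β, dim=0)
ᾱ[0], ᾱ[-1]   # roughly (0.9999, 4e-5)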
Okay now we’re ready to train a model!
Let's create our Learner. We'll add our callbacks and train with MSE loss.
We specify the number of timesteps and the minimum and maximum variance for the DDPM model.
lr = 4e-3
epochs = 3
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
ddpm_cb = DDPMCB(n_steps=1000, beta_min=0.0001, beta_max=0.02)
cbs = [ddpm_cb, DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched)]
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=optim.Adam)
Now let’s run the fit function:
learn.fit(epochs)
loss | epoch | train |
---|---|---|
0.053 | 0 | train |
0.023 | 0 | eval |
0.020 | 1 | train |
0.018 | 1 | eval |
0.017 | 2 | train |
0.016 | 2 | eval |
mdl_path = Path('models')
mdl_path.mkdir(exist_ok=True)
torch.save(learn.model, mdl_path/'fashion_ddpm.pkl')
learn.model = torch.load(mdl_path/'fashion_ddpm.pkl')
Inference
Now that we've trained our model, let's generate some images with it:
set_seed(42)
samples = ddpm_cb.sample(learn.model, (16, 1, 32, 32))
len(samples)
1000
show_images(-samples[-1], figsize=(5,5))
Let’s visualize the sampling process:
[999]*10
[999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
import matplotlib.animation as animation
from IPython.display import display, HTML
fig,ax = plt.subplots(figsize=(3,3))
def _show_i(i): return show_image(-samples[i][9], ax=ax, animated=True).get_images()
r = L.range(700,900, 4)+L.range(900,1000,1)+[999]*10
ims = r.map(_show_i)

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=3000)
HTML(animate.to_jshtml())
Note that I only show the later steps (from around step 700 onwards), since most of the earlier steps are actually quite noisy. This is a limitation of the noise schedule when used on small images, and papers like Improved DDPM suggest other noise schedules for this purpose! (Some potential homework: try out the noise schedule from Improved DDPM and see if it helps.)
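As a starting point for that homework, here is a sketch (not part of the notebook) of the cosine ᾱ schedule from Improved DDPM (Nichol & Dhariwal, 2021); the offset s=0.008 and the clipping of the implied β values follow that paper. Plugging the resulting β into DDPMCB would need a small tweak to its __init__, since it currently builds a linear schedule itself.

import math, torch

def cosine_schedule(n_steps, s=0.008, max_beta=0.999):
    # ᾱ_t follows a squared-cosine curve instead of the product of a linear β schedule
    f = lambda t: math.cos((t/n_steps + s)/(1 + s) * math.pi/2)**2
    ᾱ = torch.tensor([f(t)/f(0) for t in range(n_steps+1)])
    β = (1 - ᾱ[1:]/ᾱ[:-1]).clamp(max=max_beta)  # recover per-step β, clipped as in the paper
    return β

β = cosine_schedule(1000)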