Diffusion Models from scratch (MNIST data) | Tutorial in 100 lines of PyTorch code

Denoising Diffusion Probabilistic Models

Papers in 100 Lines of Code
3 min read · Sep 28, 2023
MNIST generated samples | Diffusion Model Tutorial

In my previous story, I implemented the initial paper on diffusion models. In this tutorial, we are going to implement the paper Denoising Diffusion Probabilistic Models, which presented high-quality image synthesis results using diffusion probabilistic models in 2020. While the results were impressive at the time, it is remarkable to see the progress that has been made since then. The next three years are going to be fascinating!

In this tutorial, we are going to mainly focus on the implementation of the diffusion model part. If you need more details about the maths and mechanisms behind the algorithm, I have a full, 10-hour course about it on Udemy.

Algorithm 1 Training

Let us start by implementing Algorithm 1 from the paper, which describes the training procedure. It is a straightforward translation of the pseudocode. The function approximator is a neural network conditioned on time; it is passed to the constructor of the DiffusionModel class, which we will define below. The variable alpha_bar is also defined in the constructor.

Denoising Diffusion Probabilistic Models (DDPM) | Tutorial in PyTorch
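In case the screenshot is hard to read, here is a minimal, self-contained sketch of the training step. The names `training_step` and `eps_model` are my own for illustration; the dummy lambda stands in for the time-conditioned network from the repository.

```python
import torch
import torch.nn as nn

def training_step(eps_model, x0, alpha_bar, T):
    """One step of Algorithm 1: noise a clean batch at a random timestep
    and regress the predicted noise against the true noise with MSE.
    `eps_model(x, t)` is an assumed signature for the time-conditioned net."""
    batch = x0.shape[0]
    t = torch.randint(0, T, (batch,), device=x0.device)  # t ~ Uniform({1..T}), 0-indexed
    eps = torch.randn_like(x0)                           # eps ~ N(0, I)
    a_bar = alpha_bar[t].view(batch, 1, 1, 1)            # broadcast over C, H, W
    # Forward process in closed form: x_t = sqrt(a_bar)*x0 + sqrt(1-a_bar)*eps
    x_t = torch.sqrt(a_bar) * x0 + torch.sqrt(1 - a_bar) * eps
    eps_pred = eps_model(x_t, t)
    return nn.functional.mse_loss(eps_pred, eps)

# Smoke test with a dummy "network" that predicts zero noise
T = 10
beta = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1.0 - beta, dim=0)
loss = training_step(lambda x, t: torch.zeros_like(x),
                     torch.randn(4, 1, 32, 32), alpha_bar, T)
```

In the actual training loop, `loss.backward()` and an optimizer step would follow this call.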

Algorithm 2 Sampling

The second algorithm, dedicated to sampling, is also a straightforward implementation of the pseudocode from the paper. Since sampling requires no gradient computation, we add the torch.no_grad decorator to prevent gradients from being tracked, and thus reduce its computational footprint.

Denoising Diffusion Probabilistic Models (DDPM) | Tutorial in PyTorch
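A self-contained sketch of Algorithm 2 is below. The function name `sample` and the `eps_model(x, t)` signature are assumptions for illustration; the dummy lambda again stands in for the trained network.

```python
import torch

@torch.no_grad()
def sample(eps_model, shape, beta, device="cpu"):
    """Algorithm 2: ancestral sampling, starting from pure Gaussian noise
    and denoising step by step. `eps_model(x, t)` is a hypothetical
    time-conditioned noise predictor."""
    alpha = 1.0 - beta
    alpha_bar = torch.cumprod(alpha, dim=0)
    T = beta.shape[0]

    x = torch.randn(shape, device=device)            # x_T ~ N(0, I)
    for t in reversed(range(T)):                     # t = T-1, ..., 0 (0-indexed)
        z = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = eps_model(x, t_batch)
        # Posterior mean of x_{t-1}, then add noise with sigma_t = sqrt(beta_t)
        x = (x - (1 - alpha[t]) / torch.sqrt(1 - alpha_bar[t]) * eps) \
            / torch.sqrt(alpha[t])
        x = x + torch.sqrt(beta[t]) * z
    return x

samples = sample(lambda x, t: torch.zeros_like(x), (2, 1, 32, 32),
                 torch.linspace(1e-4, 0.02, 10))
```

Note that no noise is added at the final step (t = 1 in the paper's notation), exactly as the pseudocode prescribes.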

Diffusion Model

The DiffusionModel constructor takes as input the number of time steps T, a time-conditioned function approximator, and the device on which it should run. We then create the variables related to the variance schedule, using the same hyperparameters as in the paper.

Denoising Diffusion Probabilistic Models (DDPM) | Tutorial in PyTorch
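A sketch of the constructor might look as follows. The class and attribute names mirror the description above; the linear beta schedule from 1e-4 to 0.02 is the one used in the paper.

```python
import torch

class DiffusionModel:
    def __init__(self, T, function_approximator, device="cpu"):
        self.T = T
        # Time-conditioned network (e.g. the U-Net); left abstract here
        self.function_approximator = function_approximator
        self.device = device

        # Linear variance schedule with the paper's hyperparameters:
        # beta increases from 1e-4 to 0.02 over T steps
        self.beta = torch.linspace(1e-4, 0.02, T, device=device)
        self.alpha = 1.0 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

model = DiffusionModel(T=1000, function_approximator=None)
```

The cumulative product alpha_bar is what lets us jump to any noise level in a single step during training.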

Putting everything together

While it may sound too easy, that’s it! By following the pseudocode from the paper, we have implemented a diffusion model for images quite intuitively, in only a few lines of code.

We can then put everything together by importing a few modules and preparing the training data. We use the MNIST dataset and create a function to sample batches from it. If you want to use your own data, you can simply modify this function. The image resolution should be a multiple of 16, which is why we resize the MNIST images from 28 x 28 to 32 x 32 with PyTorch’s interpolate function.

The U-Net module (as well as the full code) is available in this GitHub repository, but you can plug in other models as function approximators. In this tutorial, I want to focus mainly on the diffusion model part, so I will not cover the U-Net implementation here; it was, however, implemented from scratch in my course. It is the same model used in the paper on the CIFAR-10 data, with 35.7M parameters, and it uses self-attention at the 16 x 16 feature-map resolution.

Denoising Diffusion Probabilistic Models (DDPM) | Tutorial in PyTorch
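The batch-sampling function could be sketched like this. In the real script the `data` tensor would come from torchvision.datasets.MNIST; random tensors stand in here so the snippet runs standalone, and the name `sample_batch` is my own for illustration.

```python
import torch
import torch.nn.functional as F

def sample_batch(data, batch_size, device="cpu"):
    """Draw a random batch and resize 28x28 images to 32x32, since the
    U-Net downsamples by factors of two and expects a multiple of 16."""
    idx = torch.randint(0, data.shape[0], (batch_size,))
    batch = data[idx].to(device)                      # (B, 1, 28, 28)
    return F.interpolate(batch, size=32, mode="bilinear",
                         align_corners=False)         # (B, 1, 32, 32)

# Stand-in for the MNIST tensor, shape (N, 1, 28, 28)
data = torch.rand(100, 1, 28, 28)
batch = sample_batch(data, batch_size=8)
```

To swap in your own dataset, only this function needs to change, as long as it returns image tensors whose side length is a multiple of 16.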

Now that we have all the data, functions, and classes we need, we can put everything together and start the training process. After a few hours of training, you should get nice samples from the MNIST distribution, similar to the images at the beginning of this tutorial.

Denoising Diffusion Probabilistic Models (DDPM) | Tutorial in PyTorch
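The overall training loop can be sketched as below. To keep it runnable in seconds, the 35.7M-parameter U-Net and the MNIST batches are replaced by a tiny convolution (which, unlike the real model, ignores the timestep) and random tensors; the structure of the loop is what matters.

```python
import torch
import torch.nn as nn

T = 10
beta = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1.0 - beta, dim=0)

# Stand-in for the time-conditioned U-Net
eps_model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(eps_model.parameters(), lr=2e-4)

losses = []
for step in range(5):                    # many thousands of steps in practice
    x0 = torch.randn(8, 1, 32, 32)       # stand-in for an MNIST batch
    t = torch.randint(0, T, (8,))
    eps = torch.randn_like(x0)
    a_bar = alpha_bar[t].view(-1, 1, 1, 1)
    # Algorithm 1: noise the batch, predict the noise, minimize the MSE
    x_t = torch.sqrt(a_bar) * x0 + torch.sqrt(1 - a_bar) * eps
    loss = nn.functional.mse_loss(eps_model(x_t), eps)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

With the real U-Net and MNIST data plugged in, this is the loop you would leave running for a few hours before calling the sampling routine.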

I hope you found this story helpful! If it provided value to you, please consider showing your appreciation by clapping for this story, and don’t forget to subscribe to stay updated on more tutorials and content related to Diffusion Models and Machine Learning.

[Full Code] | [Diffusion Model Course] | [Consulting] | [Career & Internships]
