Diffusion Models from scratch | Tutorial in 100 lines of PyTorch code

Implementation of the initial paper on Diffusion Models

Papers in 100 Lines of Code
Jun 5, 2023

Diffusion Models, renowned for their image generation capabilities, are generative models that produce high-quality synthetic data. They have become a popular subject within the Machine Learning community and have also drawn significant public attention, largely thanks to Midjourney. The growing demand for Diffusion Models experts is evident in the rising number of job postings from major tech companies, a sign of their widespread adoption in the industry.

In this tutorial, we will build a PyTorch implementation of the initial Diffusion Models paper (Sohl-Dickstein et al., 2015, "Deep Unsupervised Learning using Nonequilibrium Thermodynamics") from the ground up, in only 100 lines of Python code. Despite its brevity, this implementation is enough to faithfully reproduce Figure 1 of the foundational paper.

The main objective of this tutorial is to provide a step-by-step implementation of Diffusion Models, focusing on the code rather than on the details of each equation. If you are interested in a deeper understanding of the technology, the equations, and the advances made since the original 2015 paper, I have created a course on Udemy that covers these topics in depth.

Forward process

In the forward process of Diffusion Models, noise is added to the data step by step until the original signal is completely destroyed, leaving a well-behaved distribution, usually a Gaussian. At each time step t, we sample the next data point x_{t+1} from the current data point x_t using a Gaussian distribution q(x_{t+1}|x_t). The mean and covariance of this distribution are determined by a set of hyper-parameters called beta, as described further in the paper's appendix.
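
A single forward step can be sketched in PyTorch as follows. This is a minimal illustration, not the article's code: it assumes the paper's Gaussian parameterization q(x_t|x_{t-1}) = N(sqrt(1 - beta_t) x_{t-1}, beta_t I), and the linear beta schedule and toy data are hypothetical choices made only for this example. Applying the step once per time step is exactly the naive, iterative way of reaching x_T.

```python
import torch

torch.manual_seed(0)

def forward_step(x_prev, beta_t):
    """Draw x_t ~ q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) * x_{t-1}, beta_t * I)."""
    noise = torch.randn_like(x_prev)
    return torch.sqrt(1.0 - beta_t) * x_prev + torch.sqrt(beta_t) * noise

def forward_iterative(x0, beta):
    """Naively apply forward_step once per time step and return x_T."""
    x = x0
    for beta_t in beta:
        x = forward_step(x, beta_t)
    return x

# Hypothetical linear schedule, for illustration only.
beta = torch.linspace(1e-4, 0.02, 1000)
x0 = torch.randn(5000, 2) * 0.1 + 3.0  # toy 2D data, far from the origin
xT = forward_iterative(x0, beta)
print(xT.mean(dim=0), xT.std(dim=0))  # roughly 0 and 1: the signal is destroyed
```

With enough steps, the toy data centred at (3, 3) ends up close to a standard Gaussian, which is the "well-behaved distribution" the reverse process will start from.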

Sampling a data point x_t naively would require iterating over t time steps. However, the conditional distribution q(x_t|x_0) has a tractable analytical form, so x_t can be sampled directly as a function of x_0, which is what the implementation does. This removes the time-consuming loop and makes sampling x_t from x_0 efficient.
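
The closed form follows from composing the Gaussian steps: with alpha_t = 1 - beta_t and alpha_bar_t the product of alpha_1 through alpha_t, q(x_t|x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I). The sketch below assumes PyTorch, a hypothetical linear schedule, and 1-based time steps stored at tensor index t - 1; the names alpha_bar and sample_xt_given_x0 are our own. It also checks numerically that iterating the per-step variance reproduces 1 - alpha_bar_t.

```python
import torch

torch.manual_seed(0)

def sample_xt_given_x0(x0, t, alpha_bar):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I).

    t is a 1-based time step; alpha_bar[t - 1] stores prod_{s<=t} (1 - beta_s).
    """
    ab = alpha_bar[t - 1]
    noise = torch.randn_like(x0)
    return torch.sqrt(ab) * x0 + torch.sqrt(1.0 - ab) * noise

beta = torch.linspace(1e-4, 0.02, 1000)  # hypothetical schedule
alpha_bar = torch.cumprod(1.0 - beta, dim=0)

x0 = torch.randn(4, 2)
x500 = sample_xt_given_x0(x0, 500, alpha_bar)  # one call instead of 500 steps

# Sanity check: the per-step variance recursion v_t = (1 - beta_t) * v_{t-1} + beta_t,
# started from v_0 = 0 (x_0 is given), reproduces the closed-form variance 1 - alpha_bar_t.
v = torch.tensor(0.0)
for beta_t in beta[:500]:
    v = (1.0 - beta_t) * v + beta_t
print(v.item(), (1.0 - alpha_bar[499]).item())  # the two values agree
```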

While the mean and covariance of q(x_t|x_0) are used to sample x_t, computing the loss requires the mean and standard deviation of a different distribution, the posterior q(x_{t-1}|x_t, x_0). Both can be computed analytically using Bayes' theorem and then incorporated directly into the code.
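
Applying Bayes' theorem to the Gaussian forward process gives the posterior in closed form: q(x_{t-1}|x_t, x_0) has variance (1 - alpha_bar_{t-1}) beta_t / (1 - alpha_bar_t) and a mean that is a weighted sum of x_0 and x_t. The function below is a sketch of that computation, assuming PyTorch, 1-based time steps stored at tensor index t - 1, alpha_t = 1 - beta_t, alpha_bar_t = prod(alpha_s) and alpha_bar_0 = 1; posterior_mean_std is a hypothetical helper name.

```python
import torch

def posterior_mean_std(x0, xt, t, beta, alpha_bar):
    """Mean and std of q(x_{t-1} | x_t, x_0) for a 1-based step t (alpha_bar_0 = 1)."""
    beta_t = beta[t - 1]
    alpha_t = 1.0 - beta_t
    ab_t = alpha_bar[t - 1]
    ab_prev = alpha_bar[t - 2] if t > 1 else torch.ones((), dtype=alpha_bar.dtype)
    # Weights of x_0 and x_t in the posterior mean.
    coef_x0 = torch.sqrt(ab_prev) * beta_t / (1.0 - ab_t)
    coef_xt = torch.sqrt(alpha_t) * (1.0 - ab_prev) / (1.0 - ab_t)
    mean = coef_x0 * x0 + coef_xt * xt
    var = (1.0 - ab_prev) * beta_t / (1.0 - ab_t)
    return mean, torch.sqrt(var)

beta = torch.linspace(1e-4, 0.02, 1000)  # hypothetical schedule
alpha_bar = torch.cumprod(1.0 - beta, dim=0)
mean, std = posterior_mean_std(torch.zeros(3, 2), torch.ones(3, 2), 100, beta, alpha_bar)
```

At t = 1 the posterior collapses onto x_0 (standard deviation 0), which matters later when the standard deviation appears inside a logarithm.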

Reverse process

The purpose of the reverse process in Diffusion Models is to approximate the inverse of the forward process at each time step t. Once the model is trained, we can therefore sample noise from the well-behaved distribution and run the reverse process to generate data samples, such as images or, in this tutorial, 2D points.

The implementation of the reverse process is straightforward: a time-dependent function approximator (for example, a neural network) estimates the distribution p(x_{t-1}|x_t) at each time step t. Because each forward step q(x_t|x_{t-1}) is a Gaussian with a small standard deviation, the reverse step can also be modeled as a Gaussian, so p(x_{t-1}|x_t) is parameterized by its mean and covariance matrix. The goal is to train the model to estimate these parameters accurately at each time step, enabling the generation of high-quality synthetic data.
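
One possible time-dependent function approximator is a small MLP that receives x_t together with the time step and outputs both the mean and a diagonal covariance of p(x_{t-1}|x_t). The architecture, the way t is fed in (t / n_steps appended to the input), and the log-variance output are all assumptions for illustration, not necessarily the article's design.

```python
import torch
import torch.nn as nn

class ReverseModel(nn.Module):
    """Hypothetical time-dependent MLP estimating p(x_{t-1} | x_t) as a diagonal Gaussian."""

    def __init__(self, data_dim=2, hidden=128, n_steps=40):
        super().__init__()
        self.n_steps = n_steps
        self.net = nn.Sequential(
            nn.Linear(data_dim + 1, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 2 * data_dim),  # mean and log-variance per dimension
        )

    def forward(self, xt, t):
        # Condition on time by appending the scalar t / n_steps to every input row.
        t_feat = torch.full((xt.shape[0], 1), t / self.n_steps,
                            dtype=xt.dtype, device=xt.device)
        mu, log_var = self.net(torch.cat([xt, t_feat], dim=1)).chunk(2, dim=1)
        sigma = torch.exp(0.5 * log_var)  # standard deviation, always positive
        return mu, sigma

model = ReverseModel()
mu, sigma = model(torch.randn(8, 2), t=10)
```

Predicting a log-variance rather than the standard deviation itself is one common way to keep the covariance positive without constraining the network's output.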

Once the model has predicted the mean and covariance of p(x_{t-1}|x_t), we can draw x_{t-1} from that distribution.
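
Sampling from the predicted Gaussian is a reparameterized draw, x_{t-1} = mu + sigma * eps with eps ~ N(0, I). In this sketch, reverse_step is a hypothetical helper, the model is assumed to return the mean and the per-dimension standard deviation, and the toy model exists only to make the snippet runnable.

```python
import torch

def reverse_step(model, xt, t):
    """Draw x_{t-1} ~ N(mu, diag(sigma^2)), where (mu, sigma) = model(xt, t)."""
    mu, sigma = model(xt, t)
    return mu + sigma * torch.randn_like(xt)

# Stand-in for a trained network: shrink toward the origin with a fixed std of 0.1.
def toy_model(x, t):
    return 0.9 * x, torch.full_like(x, 0.1)

x_prev = reverse_step(toy_model, torch.randn(1000, 2), t=40)
```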

Sampling

Once the reverse process is implemented, it can be applied repeatedly, starting from noise sampled from the well-behaved distribution: each call maps x_t to x_{t-1}, and after the final step the generated data is expected to resemble the training data.
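
The sampling loop can be sketched as below, assuming a model with a hypothetical (mu, sigma) interface and a standard Gaussian at step T. Gradients are disabled because no training happens during generation.

```python
import torch

@torch.no_grad()
def sample(model, n_samples, n_steps, data_dim=2):
    """Run the reverse process from x_T ~ N(0, I) down to x_0.

    `model(x, t)` is assumed to return (mu, sigma) of p(x_{t-1} | x_t).
    """
    x = torch.randn(n_samples, data_dim)
    for t in range(n_steps, 0, -1):  # t = T, T-1, ..., 1
        mu, sigma = model(x, t)
        x = mu + sigma * torch.randn_like(x)
    return x

# Stand-in model for illustration only.
samples = sample(lambda x, t: (0.9 * x, torch.full_like(x, 0.05)), n_samples=500, n_steps=40)
```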

Constructor

To complete our model, we write the constructor. It takes several inputs, including a time-dependent model, the number of diffusion steps, and the device on which it will run. It defines the hyper-parameters beta, which represent the variance of each diffusion step in the forward process, as well as a few variables derived from beta that make the rest of the implementation more concise.
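
A constructor along these lines might look like the sketch below. The class name, the linear beta schedule, and the derived variables alpha and alpha_bar (the running product of alpha) are assumptions; they are the kind of beta-dependent quantities that both the forward process and the posterior need.

```python
import torch

class DiffusionModel:
    """Hypothetical container for the network and the beta-dependent quantities."""

    def __init__(self, model, n_steps=40, device="cpu"):
        self.model = model
        self.n_steps = n_steps
        self.device = device
        # Variance of each forward step; a linear schedule is assumed here.
        self.beta = torch.linspace(1e-4, 0.2, n_steps, device=device)
        # Derived variables: alpha_t = 1 - beta_t and alpha_bar_t = prod_{s<=t} alpha_s.
        self.alpha = 1.0 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

diffusion = DiffusionModel(model=None, n_steps=40)
print(diffusion.alpha_bar[-1].item())  # share of x_0's variance left after the last step
```

Precomputing alpha_bar once means q(x_t|x_0) and the posterior can be evaluated for any t with a single index lookup.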

Training

During training, the objective is to maximize the model's log-likelihood, which unfortunately cannot be computed analytically. Instead, as shown in the paper, we can compute a lower bound on the log-likelihood. This bound involves the Kullback-Leibler (KL) divergence between the posterior q(x_{t-1}|x_t, x_0) and the model's p(x_{t-1}|x_t). Since both distributions are Gaussian, the KL divergence has a well-known closed form and can be computed directly. By maximizing the lower bound, we indirectly maximize the model's log-likelihood.
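
For diagonal Gaussians, the closed form is KL = sum over dimensions of log(sigma_p / sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 sigma_p^2) - 1/2. A minimal PyTorch sketch, with gaussian_kl as a hypothetical helper name:

```python
import torch

def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """KL(N(mu_q, diag(sigma_q^2)) || N(mu_p, diag(sigma_p^2))), summed over the last dim."""
    kl = (torch.log(sigma_p / sigma_q)
          + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
          - 0.5)
    return kl.sum(dim=-1)

# Two 2D unit-variance Gaussians whose means differ by 1 in each dimension.
kl = gaussian_kl(torch.zeros(1, 2), torch.ones(1, 2), torch.ones(1, 2), torch.ones(1, 2))
print(kl)  # tensor([1.])
```

During training, q would be the posterior q(x_{t-1}|x_t, x_0) and p the model's prediction. With the convention alpha_bar_0 = 1, the posterior standard deviation is zero at t = 1, so this formula is only finite for t >= 2.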

Putting it all together

Finally, all the components can be combined into the complete implementation. After a few hours of training, the model should have learned a good generative model of the swiss roll distribution.
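
A compact end-to-end sketch that assembles these pieces is shown below. It is not the article's full code: the 2D spiral used as the swiss roll, the linear schedule, the MLP, the optimizer settings, and training on the KL term of a single randomly drawn time step per batch are all assumptions. It runs only 200 iterations so that it finishes in seconds rather than hours.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
N_STEPS = 40  # hypothetical number of diffusion steps

def sample_swiss_roll(n):
    """2D spiral (s cos s, s sin s) with s in [1.5 pi, 4.5 pi], rescaled to roughly unit size."""
    s = 1.5 * math.pi * (1 + 2 * torch.rand(n))
    return torch.stack([s * torch.cos(s), s * torch.sin(s)], dim=1) / 10.0

class ReverseModel(nn.Module):
    """Hypothetical MLP mapping (x_t, t) to the mean and std of p(x_{t-1} | x_t)."""
    def __init__(self, hidden=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(3, hidden), nn.ReLU(),
                                 nn.Linear(hidden, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 4))

    def forward(self, xt, t):
        t_feat = torch.full((xt.shape[0], 1), t / N_STEPS)
        mu, log_var = self.net(torch.cat([xt, t_feat], dim=1)).chunk(2, dim=1)
        return mu, torch.exp(0.5 * log_var)

beta = torch.linspace(1e-4, 0.2, N_STEPS)  # hypothetical schedule
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)

def loss_fn(model, x0):
    # 1-based step; t >= 2 keeps the posterior std strictly positive.
    t = int(torch.randint(2, N_STEPS + 1, (1,)))
    ab_t, ab_prev = alpha_bar[t - 1], alpha_bar[t - 2]
    # Sample x_t directly from q(x_t | x_0).
    xt = torch.sqrt(ab_t) * x0 + torch.sqrt(1 - ab_t) * torch.randn_like(x0)
    # Posterior q(x_{t-1} | x_t, x_0) from Bayes' theorem.
    mu_q = (torch.sqrt(ab_prev) * beta[t - 1] * x0
            + torch.sqrt(alpha[t - 1]) * (1 - ab_prev) * xt) / (1 - ab_t)
    sigma_q = torch.sqrt((1 - ab_prev) * beta[t - 1] / (1 - ab_t))
    mu_p, sigma_p = model(xt, t)
    # Closed-form KL between the two diagonal Gaussians, averaged over the batch.
    kl = (torch.log(sigma_p / sigma_q)
          + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2) - 0.5)
    return kl.sum(dim=1).mean()

@torch.no_grad()
def sample(model, n):
    x = torch.randn(n, 2)
    for t in range(N_STEPS, 0, -1):
        mu, sigma = model(x, t)
        x = mu + sigma * torch.randn_like(x)
    return x

model = ReverseModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(200):  # far fewer iterations than a real training run
    loss = loss_fn(model, sample_swiss_roll(512))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

samples = sample(model, 1000)
```

With far longer training, a scatter plot of `samples` is the natural way to compare the generated points with the swiss roll.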

I hope you found this story helpful! If it provided value to you, please consider showing your appreciation by clapping for this story. Don’t forget to subscribe to stay updated on more tutorials and content related to Diffusion Models and Machine Learning.

Your support is greatly appreciated, and it motivates me to create more useful and informative material. Thank you!

[Full Code] | [Udemy Course] | [Consulting] | [Career & Internships]
