Generative Adversarial Networks in 100 lines of code
In my previous post, I implemented the paper Auto-Encoding Variational Bayes in order to generate pictures of human faces. In this post, we will keep exploring generative models and will focus on generative adversarial networks.
We will implement the famous Generative Adversarial Networks paper in about 100 lines of code.
My aim is to reproduce the results from the paper as closely as possible, so this time we will not generate pictures of human faces; instead, we will train our model on the MNIST dataset to generate handwritten digits.
The main idea behind generative adversarial networks is to jointly train a generator and a discriminator in an adversarial game: the discriminator learns to distinguish real data from data produced by the generator, while the generator learns to fool it. The discriminator is therefore trained to maximize the following value function:

\max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
while the generator is trained to minimize the capacity of the discriminator to correctly label the data it generates:

\min_G \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
The discriminator is a neural network that takes an image x as input and outputs a single number, the probability that x comes from the training dataset rather than from the generator. As in the paper, we will use a simple multilayer perceptron to model the discriminator.
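A minimal sketch of such a discriminator in PyTorch might look like the following (the layer sizes and LeakyReLU slope are illustrative assumptions, not necessarily the exact configuration used in the post's repository):

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """MLP that maps a flattened 28x28 image to the probability it is real."""

    def __init__(self, img_dim: int = 784, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # output in (0, 1): probability that x is real
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten (batch, 1, 28, 28) images into (batch, 784) vectors.
        return self.net(x.view(x.size(0), -1))
```

The sigmoid output layer is what lets us interpret the result as a probability and plug it directly into a binary cross-entropy loss later on.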
The generator learns to map input noise z to handwritten digits; therefore, noise needs to be sampled and fed to the generator in order to sample new data points.
The generator is also a simple multilayer perceptron.
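A sketch of the generator, under the same caveat that the latent dimension and layer sizes here are assumptions chosen for illustration:

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """MLP that maps a latent noise vector z to a flattened 28x28 image."""

    def __init__(self, z_dim: int = 100, img_dim: int = 784, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, img_dim),
            nn.Tanh(),  # pixel values in (-1, 1), matching MNIST normalized to that range
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)


# Sampling new data points: draw noise from a standard normal and feed it in.
G = Generator()
z = torch.randn(8, 100)  # 8 noise vectors of dimension 100
fake_digits = G(z)       # shape (8, 784)
```

The Tanh output assumes the training images are normalized to [-1, 1]; if they are kept in [0, 1], a Sigmoid would be the matching choice.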
Now that we have our two main pieces (the generator and the discriminator), we can combine them in a main optimization loop to train the networks. The implementation follows straightforwardly from the pseudo-code in the paper.
Note that the objective equations have been rewritten in terms of the binary cross-entropy loss. This is a safer way to implement them: we do not need to worry about numerical instabilities, as they are handled by the PyTorch implementation of binary cross entropy.
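One training iteration can be sketched as below. The stand-in networks, optimizer settings, and the non-saturating generator loss (maximizing log D(G(z)) via BCE against "real" labels, a common practical choice the paper itself suggests) are assumptions for illustration; the post's repository may differ in its details:

```python
import torch
import torch.nn as nn

# Hypothetical minimal networks and optimizers standing in for the
# Generator and Discriminator classes defined earlier.
z_dim, img_dim = 100, 784
G = nn.Sequential(nn.Linear(z_dim, 256), nn.ReLU(), nn.Linear(256, img_dim), nn.Tanh())
D = nn.Sequential(nn.Linear(img_dim, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid())
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)
bce = nn.BCELoss()


def train_step(real_images: torch.Tensor) -> tuple[float, float]:
    """One iteration of the paper's pseudo-code, with both objectives
    expressed as binary cross-entropy losses."""
    batch_size = real_images.size(0)
    real_labels = torch.ones(batch_size, 1)
    fake_labels = torch.zeros(batch_size, 1)

    # Discriminator step: maximize log D(x) + log(1 - D(G(z))).
    z = torch.randn(batch_size, z_dim)
    fake_images = G(z).detach()  # do not backpropagate into G here
    d_loss = bce(D(real_images), real_labels) + bce(D(fake_images), fake_labels)
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator step: the non-saturating variant maximizes log D(G(z)),
    # i.e. BCE against "real" labels.
    z = torch.randn(batch_size, z_dim)
    g_loss = bce(D(G(z)), real_labels)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()
```

Detaching the fake images in the discriminator step ensures that its gradient update never modifies the generator, keeping the two optimizers cleanly separated.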
Now that everything is ready, let us train for 50,000 epochs on the MNIST dataset:
Once trained, the generator can sample images that look similar to those in the training set:
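Sampling reduces to drawing fresh noise and running a forward pass, for example (the architecture here is a stand-in; in practice you would load the trained generator's weights):

```python
import torch
import torch.nn as nn

z_dim, img_dim = 100, 784
# Stand-in generator with the same assumed architecture as during training;
# load the trained weights here with G.load_state_dict(...) in practice.
G = nn.Sequential(nn.Linear(z_dim, 256), nn.ReLU(), nn.Linear(256, img_dim), nn.Tanh())
G.eval()

with torch.no_grad():               # no gradients needed at sampling time
    z = torch.randn(64, z_dim)      # 64 fresh noise vectors
    samples = G(z).view(-1, 1, 28, 28)  # reshape for plotting as an image grid
```

The `torch.no_grad()` context avoids building the autograd graph, which makes sampling cheaper.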
The full code is available in the following GitHub repository.
I hope you enjoyed the story. If you did, please leave me a clap and follow me for similar content. On the generative modelling track, you may be interested in this post where I talk about Normalizing Flows.
If you want to dive into deep generative modelling, here are two great books I strongly recommend:
- Generative Deep Learning: Teaching Machines to Paint, Write, Compose, and Play: https://amzn.to/3xLG9Zh
- Generative AI with Python and TensorFlow 2: Create images, text, and music with VAEs, GANs, LSTMs, Transformer models: https://amzn.to/3g4Y9Ia
Disclosure: I only recommend products I use myself, and all opinions expressed here are my own. This post may contain affiliate links that are at no additional cost to you, I may earn a small commission.