Stabilizing GAN Training: A Deep Dive into Wasserstein GANs
Implementing WGANs from Scratch in PyTorch: Stable GAN Training in 100 Lines of Code

Generative Adversarial Networks (GANs) have revolutionised the field of generative modelling, enabling the creation of realistic images, music, and more. However, training GANs can be notoriously unstable, often suffering from issues like mode collapse and vanishing gradients. The Wasserstein GAN (WGAN) addresses these challenges by replacing the Jensen-Shannon divergence implicit in the standard GAN objective with the Wasserstein-1 (Earth Mover's) distance, leading to more stable training and improved performance.
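In its Kantorovich-Rubinstein dual form, which is the form the WGAN objective builds on, the Wasserstein-1 distance between the real distribution and the generated distribution is:

$$
W(\mathbb{P}_r, \mathbb{P}_g) = \sup_{\|f\|_L \le 1} \; \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g}[f(x)]
$$

where the supremum ranges over all 1-Lipschitz functions f. The critic's job, as we'll see below, is to approximate this function f.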
In this blog post, we’ll dive deep into the implementation of WGANs using PyTorch, building each component from scratch in just 100 lines of code. We’ll cover:
- The architecture of the generator and discriminator (also known as the critic in WGANs)
- The training algorithm specific to WGANs
- Weight initialization techniques
- Dataset preparation
- Helper functions
- The main function tying everything together
Wasserstein GAN: Generator
The generator in a WGAN is architecturally identical to that in a standard GAN: it takes a noise vector as input and transforms it into a data sample (e.g., an image).
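As a minimal sketch (the latent dimension, feature widths, and 32x32 single-channel output are illustrative assumptions, not fixed by the post), a DCGAN-style generator might look like this:

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Maps a latent vector z of shape (N, z_dim, 1, 1) to a 32x32 image."""

    def __init__(self, z_dim: int = 100, channels: int = 1, features: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            # (N, z_dim, 1, 1) -> (N, features*4, 4, 4)
            nn.ConvTranspose2d(z_dim, features * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features * 4),
            nn.ReLU(inplace=True),
            # -> (N, features*2, 8, 8)
            nn.ConvTranspose2d(features * 4, features * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.ReLU(inplace=True),
            # -> (N, features, 16, 16)
            nn.ConvTranspose2d(features * 2, features, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU(inplace=True),
            # -> (N, channels, 32, 32), squashed into [-1, 1] to match
            # images normalised to that range
            nn.ConvTranspose2d(features, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)


# Quick shape check
if __name__ == "__main__":
    z = torch.randn(8, 100, 1, 1)
    print(Generator()(z).shape)  # torch.Size([8, 1, 32, 32])
```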
Wasserstein GAN: Discriminator (Critic)
In WGANs, the discriminator is referred to as the critic because it no longer classifies images as real or fake; instead, it assigns each image an unbounded scalar score, and the gap between its average scores on real and generated images approximates the Wasserstein distance. For this approximation to hold, the critic must be (roughly) 1-Lipschitz, which the original WGAN enforces by clipping the critic's weights to a small range after each update.
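Here is a matching sketch of the critic, under the same illustrative assumptions as the generator above. Note the unbounded output: there is no sigmoid at the end.

```python
import torch
import torch.nn as nn

class Critic(nn.Module):
    """Scores 32x32 images; higher scores should mean 'more real'."""

    def __init__(self, channels: int = 1, features: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            # (N, channels, 32, 32) -> (N, features, 16, 16)
            nn.Conv2d(channels, features, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, features*2, 8, 8); InstanceNorm is a common choice in
            # WGAN critics, though BatchNorm also works with weight clipping
            nn.Conv2d(features, features * 2, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(features * 2, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, features*4, 4, 4)
            nn.Conv2d(features * 2, features * 4, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(features * 4, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 1, 1, 1): a single unbounded score per image, no sigmoid
            nn.Conv2d(features * 4, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x).view(-1)


# In the training loop, the original WGAN keeps the critic roughly
# 1-Lipschitz by clipping its weights after each critic update:
#   for p in critic.parameters():
#       p.data.clamp_(-0.01, 0.01)
```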