Classifier-Free Diffusion Guidance
One of the key techniques that has significantly improved the performance of diffusion models is classifier-free guidance. In this post, we'll explore what classifier-free guidance is and how it works, then implement it from scratch in PyTorch.
What is Classifier-Free Guidance?
At its core, classifier-free guidance is an elegant technique that allows us to control the generation process of diffusion models without requiring a separate classifier. The key insight is that we can create a more powerful conditional generation process by combining both conditional and unconditional generation in a clever way.
Think of it like having two artists working together:
- One artist (conditional model) who follows specific instructions
- One artist (unconditional model) who creates freely without constraints
By combining their perspectives with different weights, we can create results that are both high-quality and well-aligned with our desired conditions.
The Mathematics Behind Classifier-Free Guidance
The core equation for classifier-free guidance is surprisingly simple:
ε̃ = (1 + w) · εθ(zt, c) − w · εθ(zt, ∅)

Where:
- ε̃ is the guided noise prediction
- w is the guidance weight
- εθ(zt, c) is the conditional model prediction
- εθ(zt, ∅) is the unconditional model prediction
The beauty of this approach is that it doesn’t require training two separate models. Instead, we train a single model that can handle both conditional and unconditional generation.
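Concretely, at sampling time the same network is simply called twice, once with the real condition and once with a "null" condition, and the two predictions are blended with the guidance weight. Here is a minimal sketch; the call signature model(z_t, t, label) is an assumption, so adapt it to however your network accepts the timestep and the condition:

```python
import torch

def guided_noise(model, z_t, t, label, null_label, w):
    """Blend conditional and unconditional noise predictions from one model.

    The model(z_t, t, label) call signature is assumed, not prescribed.
    """
    eps_cond = model(z_t, t, label)         # εθ(z_t, c): prediction with the real condition
    eps_uncond = model(z_t, t, null_label)  # εθ(z_t, ∅): prediction with the null condition
    return (1 + w) * eps_cond - w * eps_uncond  # ε̃, the guided prediction
```

With w = 0 this reduces to the plain conditional prediction, and larger w pushes samples harder toward the condition, typically at some cost in diversity.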
Implementation: A Complete Example
Let’s implement classifier-free guidance for a diffusion model from scratch. We’ll build a system that can generate MNIST-like digits conditioned on class labels.
First, let’s create our improved diffusion model:
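One possible realization is sketched below; the class names and sizes are illustrative, and a small residual MLP stands in for a full U-Net. The important classifier-free detail is the extra "null" label index reserved for unconditional prediction, alongside time embeddings, layer normalization, and residual connections:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """A small residual block with layer normalization."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.net(self.norm(x))

class ConditionalDenoiser(nn.Module):
    """Predicts the noise added to a flattened 28x28 image, conditioned on a
    timestep in [0, 1] and a class label. Label index `num_classes` is the
    reserved "null" token used for unconditional prediction."""
    def __init__(self, num_classes=10, dim=256, image_size=28 * 28, depth=4):
        super().__init__()
        self.null_label = num_classes  # extra embedding row for the null condition
        self.time_embed = nn.Sequential(nn.Linear(1, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.label_embed = nn.Embedding(num_classes + 1, dim)
        self.in_proj = nn.Linear(image_size, dim)
        self.blocks = nn.ModuleList(ResidualBlock(dim) for _ in range(depth))
        self.out_proj = nn.Linear(dim, image_size)

    def forward(self, x, t, labels):
        # x: (B, 784) noisy images, t: (B,) timesteps, labels: (B,) integer labels
        h = self.in_proj(x) + self.time_embed(t[:, None]) + self.label_embed(labels)
        for block in self.blocks:
            h = block(h)
        return self.out_proj(h)
```

Reserving one extra embedding row for the null condition is what lets a single network serve as both the conditional and the unconditional model.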
Now, let’s implement an improved training loop with classifier-free guidance:
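A sketch of such a loop is shown below, assuming the ConditionalDenoiser above, a standard linear noise schedule, and an MNIST DataLoader from torchvision; every hyperparameter here is illustrative rather than prescriptive. The classifier-free ingredient is the random replacement of labels with the null token:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
T = 1000            # number of diffusion timesteps
p_uncond = 0.1      # probability of dropping the label (trains the unconditional path)
num_epochs = 20

# Linear noise schedule; alpha_bars[t] is the cumulative product of (1 - beta)
betas = torch.linspace(1e-4, 0.02, T, device=device)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

dataset = datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

model = ConditionalDenoiser().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs * len(dataloader))

for epoch in range(num_epochs):
    for images, labels in dataloader:
        x0 = images.view(images.size(0), -1).to(device)   # (B, 784) in [0, 1]
        labels = labels.to(device)

        # Classifier-free training: randomly replace labels with the null token
        drop = torch.rand(labels.shape, device=device) < p_uncond
        labels = torch.where(drop, torch.full_like(labels, model.null_label), labels)

        # Sample a timestep per example and add the corresponding amount of noise
        t = torch.randint(0, T, (x0.size(0),), device=device)
        noise = torch.randn_like(x0)
        a_bar = alpha_bars[t][:, None]
        x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise

        pred = model(x_t, t.float() / T, labels)

        # Illustrative timestep weighting: emphasize small t, i.e. the later
        # steps of the reverse (denoising) process
        weights = (2.0 - t.float() / T)[:, None]
        loss = (weights * (pred - noise) ** 2).mean()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
        optimizer.step()
        scheduler.step()
```

A label-drop probability around 10-20% is a common choice; it gives the unconditional path enough training signal without weakening the conditional one.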
Finally, let’s improve the sampling process with classifier-free guidance:
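The sketch below reuses the model, noise schedule, and device from the training sketch. For simplicity it uses evenly spaced timesteps rather than sigmoid spacing, and a deterministic DDIM-style update; the guidance weight and step count are illustrative:

```python
import torch

@torch.no_grad()
def sample(model, labels, w=3.0, num_steps=50):
    """Generate images with classifier-free guidance using DDIM-style steps.

    labels: (B,) tensor of class indices; w: guidance weight.
    Relies on T, alpha_bars, and device defined in the training sketch.
    """
    model.eval()
    B = labels.size(0)
    x = torch.randn(B, 28 * 28, device=device)
    null = torch.full_like(labels, model.null_label)
    steps = torch.linspace(T - 1, 0, num_steps, device=device).long()

    for i in range(num_steps):
        t = steps[i].expand(B)        # same timestep for the whole batch
        a_bar = alpha_bars[steps[i]]  # scalar cumulative alpha for this step

        # Guided noise prediction: (1 + w) * conditional - w * unconditional
        eps_cond = model(x, t.float() / T, labels)
        eps_uncond = model(x, t.float() / T, null)
        eps = (1 + w) * eps_cond - w * eps_uncond

        # Estimate the clean image, then take a deterministic DDIM step
        x0_pred = ((x - (1 - a_bar).sqrt() * eps) / a_bar.sqrt()).clamp(0, 1)
        if i < num_steps - 1:
            a_bar_next = alpha_bars[steps[i + 1]]
            x = a_bar_next.sqrt() * x0_pred + (1 - a_bar_next).sqrt() * eps
        else:
            x = x0_pred

    return x.view(B, 1, 28, 28)

# Example: one sample of each digit with a moderate guidance weight
digits = torch.arange(10, device=device)
samples = sample(model, digits, w=3.0)
```

Raising w sharpens class adherence; very large values tend to produce saturated, less diverse samples.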
Understanding the Improvements
Our implementation includes several key improvements over the basic version:
- Enhanced Architecture:
  - Added time embeddings for better temporal understanding
  - Included layer normalization for stable training
  - Added residual connections in the U-Net structure
- Improved Training:
  - Used the AdamW optimizer with weight decay for better regularization
  - Implemented learning rate scheduling
  - Added gradient clipping to prevent exploding gradients
  - Weighted the loss by timestep to focus more on later denoising steps
- Better Sampling:
  - Improved timestep spacing using sigmoid scaling
  - A more stable DDIM-like stepping procedure
  - Better handling of batch dimensions