Unpacking Distribution Matching Distillation

Om Rastogi

Diffusion models excel at generating realistic images but suffer from high computational costs due to their iterative denoising process. Distribution Matching Distillation (DMD) tackles this challenge by condensing a multi-step diffusion process into an efficient one-step generator. The method combines a distribution matching loss with a GAN loss to push generated (fake) images toward realistic ones, paving the way for faster generative imaging applications.

Distribution Matching

The one-step generator does not explicitly model the data distribution; it never computes a density. Instead, it is trained so that the distribution of its outputs aligns with the target distribution. Rather than approximating the reverse diffusion process step by step, the one-step generator maps noise directly to samples from the target distribution.

That is where the distillation component comes in: a pretrained teacher diffusion model provides a high-quality representation of the target distribution, which the generator learns to match.

Step-by-Step Process of DMD

Step 0: Initialization

  1. The one-step generator is initialized from the pretrained diffusion UNet, with its timestep input fixed at T−1.
  2. real_unet is the frozen teacher, which represents the real data distribution.
  3. fake_unet is trained to model the generator's output distribution.
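The initialization can be sketched with a toy stand-in for a UNet. ToyUNet, init_dmd and the trainable flag are hypothetical names used only for illustration; a real setup would copy a pretrained diffusion UNet. Initializing fake_unet from the teacher as well is an assumption of this sketch.

```python
import copy
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyUNet:
    """Hypothetical stand-in for a diffusion UNet: just a vector of weights."""
    weights: List[float]
    trainable: bool = True


def init_dmd(teacher: ToyUNet, T: int = 1000) -> Tuple[ToyUNet, int, ToyUNet, ToyUNet]:
    """Build the generator, real_unet and fake_unet from one pretrained teacher."""
    # One-step generator: a trainable copy of the teacher, always run at timestep T - 1.
    generator = copy.deepcopy(teacher)
    generator_timestep = T - 1
    # real_unet: a frozen copy of the teacher; it represents the real distribution.
    real_unet = copy.deepcopy(teacher)
    real_unet.trainable = False
    # fake_unet: a trainable copy (assumption) that will track the generator's outputs.
    fake_unet = copy.deepcopy(teacher)
    fake_unet.trainable = True
    return generator, generator_timestep, real_unet, fake_unet


teacher = ToyUNet([0.5, -1.0])
generator, t_gen, real_unet, fake_unet = init_dmd(teacher)
```

Deep copies keep the three networks independent, so later updates to the generator or fake_unet never touch the teacher weights held by real_unet.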

Step 1: Generate Image from Noise

  1. A random noise map is passed to the generator.
  2. The generator denoises this noise in one step, producing an image x.
  3. At this stage, x is a sample from the generator's distribution, p_fake.
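Step 1 as a sketch: the generator is any callable that takes a noise map and a timestep and returns an image. generate_one_step and toy_generator are hypothetical names; the point is that sampling is a single forward pass at the fixed timestep T − 1.

```python
import random
from typing import Callable, List, Optional, Tuple

Image = List[float]


def generate_one_step(generator: Callable[[Image, int], Image], num_pixels: int,
                      T: int = 1000, seed: Optional[int] = None) -> Tuple[Image, Image]:
    """Sample a noise map z and map it to an image x with one generator call."""
    rng = random.Random(seed)
    z = [rng.gauss(0.0, 1.0) for _ in range(num_pixels)]  # random noise map
    x = generator(z, T - 1)  # single forward pass at the fixed timestep T - 1
    return z, x


def toy_generator(z: Image, t: int) -> Image:
    """Hypothetical generator: an affine map that ignores the timestep."""
    return [0.5 * v + 1.0 for v in z]


z, x = generate_one_step(toy_generator, num_pixels=4, seed=0)
```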

Step 2: Add Gaussian Noise

  1. The image x is perturbed with Gaussian noise to obtain a noisy version xt.
  2. The timestep t is sampled uniformly between 0.02T and 0.98T, excluding the extreme near-noiseless and near-pure-noise cases.
  3. This noise injection makes the noisy versions of p_fake and p_real overlap, so the two distributions can be compared.
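The noising step can be sketched as follows, assuming a DDPM-style forward process x_t = √ᾱ_t·x + √(1 − ᾱ_t)·ε with the linear β schedule from 10⁻⁴ to 0.02 over T = 1000 steps used by the original DDPM paper. The helper names are hypothetical, and the timestep bounds are arguments given as fractions of T.

```python
import math
import random
from typing import List


def ddpm_alphas_cumprod(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> List[float]:
    """Cumulative products alpha_bar_t of an (assumed) linear DDPM beta schedule."""
    alphas_cumprod, prod = [], 1.0
    for i in range(T):
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def add_noise(x, alphas_cumprod, min_frac=0.02, max_frac=0.98, rng=random):
    """Perturb x into x_t = sqrt(a_bar_t) * x + sqrt(1 - a_bar_t) * eps at a random t."""
    T = len(alphas_cumprod)
    lo = round(min_frac * T)
    hi = min(round(max_frac * T), T - 1)
    t = rng.randint(lo, hi)  # uniform over the inclusive range [lo, hi]
    a_bar = alphas_cumprod[t]
    eps = [rng.gauss(0.0, 1.0) for _ in x]
    x_t = [math.sqrt(a_bar) * xi + math.sqrt(1.0 - a_bar) * ei for xi, ei in zip(x, eps)]
    return x_t, t, eps
```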

Step 3: Pass Through real_unet and fake_unet

The noisy image xt is passed through real_unet (frozen) and fake_unet (trainable). This gives the following:

  1. real_unet produces pred_real_image, the frozen teacher's estimate of the clean image under the real data distribution.
  2. fake_unet produces pred_fake_image, its estimate of the clean image under the generator's current distribution at this timestep.
  3. The difference between pred_real_image and pred_fake_image measures the discrepancy between the real and fake distributions at xt.
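A sketch of Step 3, assuming both UNets predict the added noise ε (epsilon-prediction) and using the relation x_t = √ᾱ_t·x + √(1 − ᾱ_t)·ε to turn that prediction into a clean-image estimate. real_unet and fake_unet are passed in as plain callables; predict_x0 and compare_predictions are hypothetical helper names.

```python
import math
from typing import Callable, List, Sequence, Tuple

Image = List[float]
EpsModel = Callable[[Image, int], Image]


def predict_x0(eps_model: EpsModel, x_t: Image, t: int, alphas_cumprod: Sequence[float]) -> Image:
    """Clean-image estimate: (x_t - sqrt(1 - a_bar) * eps_hat) / sqrt(a_bar)."""
    a_bar = alphas_cumprod[t]
    eps_hat = eps_model(x_t, t)
    return [(xi - math.sqrt(1.0 - a_bar) * ei) / math.sqrt(a_bar) for xi, ei in zip(x_t, eps_hat)]


def compare_predictions(real_unet: EpsModel, fake_unet: EpsModel, x_t: Image, t: int,
                        alphas_cumprod: Sequence[float]) -> Tuple[Image, Image, Image]:
    """Return pred_real_image, pred_fake_image and their difference (fake minus real)."""
    pred_real_image = predict_x0(real_unet, x_t, t, alphas_cumprod)  # frozen teacher
    pred_fake_image = predict_x0(fake_unet, x_t, t, alphas_cumprod)  # tracks the generator
    difference = [f - r for f, r in zip(pred_fake_image, pred_real_image)]
    return pred_real_image, pred_fake_image, difference
```

When the two networks agree, the difference is zero and there is nothing left to correct; any nonzero entry points from the real estimate toward the fake one.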

Step 4: Compute Loss

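  1. First, grad is formed from the discrepancy between the two predictions, grad = pred_fake_image − pred_real_image, scaled by a per-sample weighting factor. The loss is then the MSE between x and x − grad, where x − grad is treated as a fixed target. The term x − grad is a corrected output: it moves x away from pred_fake_image and toward pred_real_image, i.e. closer to the real data distribution, by removing the discrepancy measured by grad.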
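A minimal sketch of Step 4 on flattened images (plain Python lists). dm_loss is a hypothetical name, and the weighting factor here, the mean absolute difference between x and pred_real_image, is an assumed choice. Because x − grad is held constant, the gradient of the loss with respect to x is grad divided by the number of pixels.

```python
from typing import List, Tuple


def dm_loss(x: List[float], pred_real_image: List[float],
            pred_fake_image: List[float]) -> Tuple[float, List[float], List[float]]:
    """Distribution matching loss for one flattened image (toy, pure Python)."""
    n = len(x)
    # Assumed weighting factor: mean |x - pred_real_image|, guarded against zero.
    weighting_factor = max(sum(abs(xi - ri) for xi, ri in zip(x, pred_real_image)) / n, 1e-8)
    grad = [(fi - ri) / weighting_factor for fi, ri in zip(pred_fake_image, pred_real_image)]
    # Shifted image x - grad; it is a constant target, so no gradient flows through it.
    target = [xi - gi for xi, gi in zip(x, grad)]
    loss = 0.5 * sum((xi - ti) ** 2 for xi, ti in zip(x, target)) / n
    # d(loss)/dx with the target held fixed: (x - target) / n, i.e. grad / n.
    dloss_dx = [(xi - ti) / n for xi, ti in zip(x, target)]
    return loss, grad, dloss_dx


loss, grad, dloss_dx = dm_loss([0.0, 0.0], [1.0, 1.0], [-1.0, -1.0])
```

Here dloss_dx is [-1.0, -1.0], so a gradient-descent step x − lr·dloss_dx moves x from 0.0 toward pred_real_image = 1.0 and away from pred_fake_image = −1.0.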

Step 5: Update the Fake Distribution

The fake_unet is updated with a diffusion (denoising) loss between pred_fake_image and x. This is the step where fake_unet learns the generator's evolving fake distribution. The loss is calculated between xt-1_pred and x, whereas a standard diffusion UNet compares its prediction with ground truth derived from real training data (xt-1_gt). As a result, fake_unet learns to denoise noisy versions of the generator's outputs (xt) back to the generator's current outputs x. During this update, x is treated as fixed data, so only fake_unet's weights change.
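A minimal sketch of this update, assuming a toy fake_unet that is just a linear clean-image predictor, pred = a·x_t + b, at one fixed noise level a_bar. train_toy_fake_unet, a, b and a_bar are hypothetical. The detail to notice is the target: the generator's outputs, not real images.

```python
import math
import random
from typing import List, Tuple


def fake_unet_loss(pred_fake_image: List[float], x: List[float]) -> float:
    """MSE between fake_unet's clean-image prediction and the generator output x."""
    return sum((p - xi) ** 2 for p, xi in zip(pred_fake_image, x)) / len(x)


def train_toy_fake_unet(gen_samples: List[float], a_bar: float = 0.5, lr: float = 0.01,
                        steps: int = 200, seed: int = 0) -> Tuple[float, float, List[float]]:
    """Fit a hypothetical linear x0-predictor, pred = a * x_t + b, at one fixed timestep."""
    rng = random.Random(seed)
    # Noisy versions x_t of the generator outputs (the outputs are fixed data here).
    eps = [rng.gauss(0.0, 1.0) for _ in gen_samples]
    x_t = [math.sqrt(a_bar) * x + math.sqrt(1.0 - a_bar) * e for x, e in zip(gen_samples, eps)]
    a, b = 1.0, 0.0
    n = len(gen_samples)
    history = []
    for _ in range(steps):
        pred = [a * xt + b for xt in x_t]
        history.append(fake_unet_loss(pred, gen_samples))
        # Gradients of the MSE w.r.t. a and b; the target (generator output) is constant.
        da = sum(2.0 * (p - x) * xt for p, x, xt in zip(pred, gen_samples, x_t)) / n
        db = sum(2.0 * (p - x) for p, x in zip(pred, gen_samples)) / n
        a -= lr * da
        b -= lr * db
    return a, b, history
```

In a real training loop the generator produces a fresh batch every iteration, so this target keeps moving and fake_unet has to keep up with it.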

Some Questions

Question 1: Why is the fake_unet updated with a loss between xt-1_pred and x (the generator's output), instead of comparing xt-1_pred with xt-1_gt?

Answer 1: We compare xt-1_pred with x because the fake_unet's goal is to denoise noisy versions of the generator's outputs (xt) back to the generator's current outputs (x). This ensures the fake_unet accurately tracks the generator's evolving “fake” distribution, which is what makes its comparison with real_unet a meaningful gradient for improving the generator's output.

Question 2: Why do we even need a fake_unet? Couldn’t the KL divergence be computed directly from the pretrained real_unet’s output and the generator’s output?

Answer 2: The KL gradient needs two scores: one for the real distribution and one for the generator's distribution. The pretrained real_unet can supply only the first; it models real data and has no knowledge of where the generator currently places its samples. Moreover, the generator is intended to produce a fully denoised image in a single step, whereas a single pass of real_unet only partially denoises its input, so real_unet's output cannot stand in for the generator's distribution either. Instead, the fake_unet is continually trained on the generator's evolving distribution, enabling it to approximate how the generator's outputs currently look. As a result, comparing real_unet and fake_unet outputs yields a direction (the approximate KL gradient) to fine-tune the generator's distribution and improve single-step image synthesis.

Distribution Matching Loss

During training, the KL divergence measures how much the generator's distribution deviates from the real distribution.

Here, Preal is the PDF of the real data distribution, while Pfake is the PDF of the fake distribution, i.e. the distribution of the generator's outputs x = Gθ(z).
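Written out, the divergence DMD minimizes takes the generator's distribution as its first argument. Expanding the expectation over generator samples x = Gθ(z):

```latex
D_{KL}\left(P_{\text{fake}} \,\|\, P_{\text{real}}\right)
  = \mathbb{E}_{x \sim P_{\text{fake}}}\left[\log \frac{P_{\text{fake}}(x)}{P_{\text{real}}(x)}\right]
  = \mathbb{E}_{z \sim \mathcal{N}(0, I),\; x = G_\theta(z)}\left[-\left(\log P_{\text{real}}(x) - \log P_{\text{fake}}(x)\right)\right]
```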

Computing probability densities for high-dimensional data is often intractable. For example, a 32×32 pixel 8-bit grayscale image has 1024 dimensions and 256¹⁰²⁴ possible pixel configurations, so evaluating or normalizing a density over that space directly is computationally prohibitive.
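The arithmetic behind that example: a 32×32 image has 1024 pixel dimensions, while the number of distinct 8-bit images is 256¹⁰²⁴ = 2⁸¹⁹². Python's arbitrary-precision integers can count this exactly.

```python
import math

pixels = 32 * 32                # one dimension per pixel
levels = 256                    # 8-bit grayscale values per pixel
num_images = levels ** pixels   # exact big integer, 256**1024 == 2**8192

digits = len(str(num_images))
digits_via_log = math.floor(pixels * math.log10(levels)) + 1

print(pixels)   # 1024
print(digits)   # 2467
```

A number with 2467 decimal digits cannot be enumerated, which is why the approach below avoids normalized densities altogether.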

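Hence, we work with the score functions of the real and fake distributions instead. A score is the gradient of the log density with respect to x, S(x) = ∇x log P(x), and it does not depend on the density's normalizing constant.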

This allows us to compute the gradient of the KL divergence without evaluating either density: Sreal moves x toward the modes of Preal, while −Sfake spreads the generator's samples apart, away from the modes of Pfake.
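To see why scores sidestep the normalization problem, take a 1D Gaussian with mean 2 and standard deviation 1 as an illustrative stand-in for Preal. The finite-difference score of the unnormalized log density matches the exact score −(x − μ)/σ², because the constant drops out of the derivative, and its sign points toward the mode.

```python
import math


def log_unnormalized(x: float, mu: float = 2.0, sigma: float = 1.0) -> float:
    """log of exp(-(x - mu)^2 / (2 sigma^2)), with the normalizing constant dropped."""
    return -((x - mu) ** 2) / (2 * sigma ** 2)


def log_normalized(x: float, mu: float = 2.0, sigma: float = 1.0) -> float:
    """Proper Gaussian log density: the unnormalized term minus log(sigma * sqrt(2 pi))."""
    return log_unnormalized(x, mu, sigma) - math.log(sigma * math.sqrt(2 * math.pi))


def score_fd(logp, x: float, h: float = 1e-5) -> float:
    """Central finite-difference estimate of d/dx log p(x)."""
    return (logp(x + h) - logp(x - h)) / (2 * h)


def score_exact(x: float, mu: float = 2.0, sigma: float = 1.0) -> float:
    """Closed-form Gaussian score."""
    return -(x - mu) / sigma ** 2
```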

Here, Sreal(x) is the score function of the real data distribution, Sfake(x) is the score function of the fake data distribution, and ∇θGθ(z) is the gradient of the generator's output x with respect to its parameters θ.
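In symbols, with the scores defined as gradients of log densities, the gradient of the KL divergence with respect to the generator parameters is:

```latex
S_{\text{real}}(x) = \nabla_x \log P_{\text{real}}(x), \qquad
S_{\text{fake}}(x) = \nabla_x \log P_{\text{fake}}(x)

\nabla_\theta D_{KL}
  = \mathbb{E}_{z \sim \mathcal{N}(0, I),\; x = G_\theta(z)}
    \left[-\left(S_{\text{real}}(x) - S_{\text{fake}}(x)\right) \nabla_\theta G_\theta(z)\right]
```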

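Sreal(x) − Sfake(x) is the difference between the real and fake scores. However, for a fake sample x that lies far from the real data, Preal(x) is nearly zero, so Sreal(x) is unreliable or ill-defined there. Hence, Gaussian noise is added to obtain xt: after this perturbation the noisy real and fake distributions overlap, and the diffusion models, which are trained on noisy inputs, can estimate both scores at xt.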

The definitions of Sreal and Sfake follow Song et al., “Score-Based Generative Modeling through Stochastic Differential Equations”.

Final Loss
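A sketch of how the pieces combine once noise is added, assuming xt = αt·x + σt·ε (in DDPM notation αt = √ᾱt and σt = √(1 − ᾱt)). μreal and μfake denote the clean-image predictions pred_real_image and pred_fake_image, and the score expressions follow from Tweedie's formula for Gaussian perturbations. The positive factor αt/σt² is folded into a weighting factor wt whose exact form is not spelled out here, and sg(·) denotes stop-gradient:

```latex
x_t = \alpha_t x + \sigma_t \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)

S_{\text{real}}(x_t, t) = -\frac{x_t - \alpha_t \mu_{\text{real}}(x_t, t)}{\sigma_t^2}, \qquad
S_{\text{fake}}(x_t, t) = -\frac{x_t - \alpha_t \mu_{\text{fake}}(x_t, t)}{\sigma_t^2}

S_{\text{fake}}(x_t, t) - S_{\text{real}}(x_t, t)
  = \frac{\alpha_t}{\sigma_t^2}\left(\mu_{\text{fake}}(x_t, t) - \mu_{\text{real}}(x_t, t)\right)

\text{grad} = w_t\left(\mu_{\text{fake}}(x_t, t) - \mu_{\text{real}}(x_t, t)\right), \qquad
\mathcal{L}_{\text{DM}} = \tfrac{1}{2}\left\lVert x - \operatorname{sg}(x - \text{grad}) \right\rVert_2^2, \qquad
\nabla_x \mathcal{L}_{\text{DM}} = \text{grad}
```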

Intuitive Understanding

The predictions of real_unet and fake_unet for the noisy input xt (at timestep t) are used to form a gradient that nudges the generator's current output, x, toward real_unet's clean-image prediction (t = 0). After this gradient-based shift, a mean squared error (MSE) is computed between the generator's original output (x) and the shifted version. Essentially, this correction step drives x to align with the real data distribution.
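Why does an MSE against the shifted image apply exactly the distribution matching gradient? Because the shifted image x − grad is treated as a constant, the gradient of ½‖x − (x − grad)‖² with respect to x is x − (x − grad) = grad. The sketch below checks this numerically with central finite differences; the values of x and grad are arbitrary.

```python
from typing import Callable, List


def surrogate_loss(x: List[float], target: List[float]) -> float:
    """0.5 * squared distance to a target that is treated as a constant."""
    return 0.5 * sum((xi - ti) ** 2 for xi, ti in zip(x, target))


def numeric_grad(f: Callable[[List[float]], float], x: List[float], h: float = 1e-6) -> List[float]:
    """Central finite-difference gradient of f at x."""
    g = []
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        g.append((f(xp) - f(xm)) / (2 * h))
    return g


x = [0.3, -1.2, 2.0]
grad = [0.5, -0.25, 1.0]                        # arbitrary stand-in for the DM gradient
target = [xi - gi for xi, gi in zip(x, grad)]   # shifted image x - grad, frozen
g = numeric_grad(lambda v: surrogate_loss(v, target), x)
```

The recovered gradient g matches grad, which is why implementations can express the distribution matching update as an ordinary MSE loss.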

[Figure: Calculation of the loss, as in the code. It visualizes the loss function at various timesteps, showing how a multi-step generator is used to train a one-step generator. Note: this diagram omits details about weighting_factor and makes a few assumptions about the underlying distributions.]

The key idea is that the gradient derived from the difference between x_fake and x_real (pred_fake_image and pred_real_image) pushes the generator's output toward real_unet's clean-image prediction at t = 0. As the generator learns, its output gradually moves toward the real distribution, and because fake_unet keeps tracking the generator, its predictions are pulled closer to the real ones as well. Consequently, the shifted image x − grad also converges toward the real distribution.
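This convergence can be reproduced in a 1D toy, which is an illustration and not the DMD implementation. The real distribution is N(2, 1), the generator is Gb(z) = z + b with a single parameter b, real_unet is replaced by the exact real score, and fake_unet is replaced by refitting N(fake_mean, 1) to the generator's current samples. Each step applies the KL gradient −(Sreal − Sfake)·∇bGb, so b and the tracked fake mean both drift toward 2. toy_dmd_1d and its parameters are hypothetical.

```python
import random
from typing import List, Tuple


def toy_dmd_1d(real_mean: float = 2.0, b_init: float = -3.0, lr: float = 0.1,
               steps: int = 200, batch: int = 256, seed: int = 0) -> Tuple[float, List[float]]:
    """1D toy of distribution matching with unit-variance Gaussians (illustrative only)."""
    rng = random.Random(seed)
    b = b_init  # generator parameter: G_b(z) = z + b, so p_fake = N(b, 1)
    fake_means = []
    for _ in range(steps):
        xs = [rng.gauss(0.0, 1.0) + b for _ in range(batch)]
        # Stand-in for fake_unet: refit p_fake ~ N(fake_mean, 1) to the current samples.
        fake_mean = sum(xs) / batch
        fake_means.append(fake_mean)
        # Unit-variance Gaussian scores S(x) = -(x - mean); this is S_real(x) - S_fake(x).
        s_diff = [-(x - real_mean) + (x - fake_mean) for x in xs]
        # KL gradient w.r.t. b: mean of -(S_real - S_fake) * dG/db, with dG/db = 1.
        grad_b = -sum(s_diff) / batch
        b -= lr * grad_b
    return b, fake_means


b_final, fake_means = toy_dmd_1d()
```

Early on the fake mean sits near −3 and the score difference is large; as b approaches 2 the two scores agree and the updates shrink, mirroring how pred_real_image and pred_fake_image converge.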

To be added:
1. Regularization with regression loss
2. GAN loss for regularization and stability
3. Update pattern of fake_unet

Epilogue

I have poured much time and effort into understanding Distribution Matching Distillation (DMD) and translating my intuition into this article. My goal has been to make the concept as clear and accessible as possible. While I’ve done my best to be thorough, I welcome any feedback or suggestions for improvement — I truly believe there’s always room to learn and grow.
