ML · 11 min read

Variational Autoencoders (VAEs) for Generative Modeling

Learn how VAEs combine autoencoders with probabilistic modeling to generate new data and learn meaningful latent representations.

Sarah Chen
December 19, 2025


Regular autoencoders compress data, but their latent space has gaps, so decoding a random latent vector rarely yields a coherent sample. VAEs fix this by learning a probabilistic latent space you can sample from.

Autoencoder vs VAE

Autoencoder:

Input → Encoder → Latent Code (fixed point) → Decoder → Reconstruction

VAE:

Input → Encoder → Latent Distribution (mean, variance) → Sample → Decoder → Reconstruction

The key difference: VAE learns a distribution, not a fixed point.

How VAE Works

  1. Encoder outputs mean (μ) and variance (σ²) for each latent dimension
  2. Sample from this distribution using the "reparameterization trick"
  3. Decoder reconstructs from the sample
  4. Loss = Reconstruction Loss + KL Divergence

The Reparameterization Trick

We can't backpropagate through a random sampling operation, so the trick is to move the randomness outside the network: sample ε from N(0,1) and transform it deterministically:

z = μ + σ × ε    where ε ~ N(0,1)

Now gradients flow through μ and σ!
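A minimal sketch of why this works (standalone, not part of the VAE below): the noise ε is sampled outside the computation graph, so autograd differentiates through the deterministic transform and both μ and σ receive gradients.

import torch

mu = torch.zeros(3, requires_grad=True)
log_sigma = torch.zeros(3, requires_grad=True)

eps = torch.randn(3)                   # ε ~ N(0,1), sampled outside the graph
z = mu + torch.exp(log_sigma) * eps    # z = μ + σ × ε (deterministic given ε)

loss = (z ** 2).sum()
loss.backward()

print(mu.grad)         # well-defined gradient w.r.t. μ
print(log_sigma.grad)  # gradient also reaches σ through the transform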

VAE Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=20):
        super().__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )
        
        # Latent space parameters
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.Sigmoid()  # For normalized inputs [0,1]
        )
    
    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)  # log(σ²) for numerical stability
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # σ = exp(0.5 * log(σ²))
        eps = torch.randn_like(std)    # ε ~ N(0,1)
        return mu + std * eps          # z = μ + σ × ε
    
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstruction = self.decode(z)
        return reconstruction, mu, logvar
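A quick smoke test (a sketch, assuming flattened 28×28 inputs scaled to [0,1]):

model = VAE()
x = torch.rand(8, 784)                       # batch of 8 fake flattened images
recon, mu, logvar = model(x)
print(recon.shape, mu.shape, logvar.shape)   # (8, 784), (8, 20), (8, 20)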

VAE Loss Function

def vae_loss(reconstruction, x, mu, logvar):
    # Reconstruction loss (how well we recreate input)
    recon_loss = F.binary_cross_entropy(reconstruction, x, reduction='sum')
    
    # KL divergence (how close latent distribution is to N(0,1))
    # Formula: -0.5 * sum(1 + log(σ²) - μ² - σ²)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return recon_loss + kl_loss

Why KL divergence? Forces the latent space to be:

  • Continuous (similar inputs → similar latent codes)
  • Centered around origin (can sample from N(0,1) to generate)
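As a sanity check (a sketch using torch.distributions, with random μ and log σ² just for illustration), the closed-form term in vae_loss matches the analytic KL between N(μ, σ²) and N(0, 1):

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 20)
logvar = torch.randn(4, 20)

closed_form = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
analytic = kl_divergence(Normal(mu, torch.exp(0.5 * logvar)), Normal(0.0, 1.0)).sum()

print(torch.allclose(closed_form, analytic, atol=1e-5))  # True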

Training VAE

def train_vae(model, train_loader, epochs=50):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    for epoch in range(epochs):
        total_loss = 0
        for batch_x, _ in train_loader:
            batch_x = batch_x.view(-1, 784)  # Flatten for MNIST
            
            optimizer.zero_grad()
            recon, mu, logvar = model(batch_x)
            loss = vae_loss(recon, batch_x, mu, logvar)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader.dataset)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

Generating New Samples

def generate_samples(model, num_samples=16):
    model.eval()
    with torch.no_grad():
        # Sample from standard normal
        z = torch.randn(num_samples, 20)  # latent_dim = 20
        samples = model.decode(z)
    return samples

# Generate and visualize
samples = generate_samples(model)
# samples shape: (16, 784) - can reshape to (16, 28, 28) for MNIST
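For example, a simple grid plot (a sketch assuming matplotlib and 28×28 MNIST-shaped outputs):

import matplotlib.pyplot as plt

images = generate_samples(model).view(-1, 28, 28)
fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for img, ax in zip(images, axes.flatten()):
    ax.imshow(img.numpy(), cmap="gray")
    ax.axis("off")
plt.show()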

Latent Space Interpolation

def interpolate(model, x1, x2, steps=10):
    model.eval()
    with torch.no_grad():
        # Get latent representations
        mu1, _ = model.encode(x1.unsqueeze(0))
        mu2, _ = model.encode(x2.unsqueeze(0))
        
        # Interpolate in latent space
        interpolations = []
        for alpha in torch.linspace(0, 1, steps):
            z = (1 - alpha) * mu1 + alpha * mu2
            interpolations.append(model.decode(z))
        
        return torch.cat(interpolations)
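Usage might look like this (a sketch; x1 and x2 are flattened 28×28 images, e.g. taken from the MNIST dataset loaded earlier):

x1 = train_ds[0][0].view(784)
x2 = train_ds[1][0].view(784)
frames = interpolate(model, x1, x2, steps=10)
print(frames.shape)  # torch.Size([10, 784])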

Convolutional VAE (for Images)

class ConvVAE(nn.Module):
    def __init__(self, latent_dim=32):
        super().__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),  # 28→14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), # 14→7
            nn.ReLU(),
            nn.Flatten()
        )
        
        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)
        
        # Decoder
        self.fc_decode = nn.Linear(latent_dim, 64 * 7 * 7)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )
    
    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)
    
    def decode(self, z):
        h = self.fc_decode(z)
        h = h.view(-1, 64, 7, 7)
        return self.decoder(h)
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
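A quick shape check (a sketch, assuming 1×28×28 inputs in [0,1]):

conv_model = ConvVAE()
x = torch.rand(8, 1, 28, 28)
recon, mu, logvar = conv_model(x)
print(recon.shape)  # torch.Size([8, 1, 28, 28])

The vae_loss above works unchanged on these 4-D tensors, so a ConvVAE training loop only needs to drop the flattening step.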

VAE vs GAN

Aspect           VAE                                    GAN
Training         Stable                                 Can be unstable
Sample quality   Blurrier                               Sharper
Latent space     Structured, interpretable              Less structured
Mode coverage    Good (covers all modes)                May miss modes
Use case         When you need latent representations   When you need sharp samples

Key Takeaway

VAEs learn meaningful latent representations by combining autoencoders with probabilistic modeling. The KL divergence term ensures a smooth, continuous latent space you can sample from. They're great for generating variations, interpolating between samples, and learning disentangled features. Start with a simple VAE, then try β-VAE, which weights the KL term more heavily, for better disentanglement.
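As a sketch of that last step, β-VAE only changes the loss: the KL term is scaled by a factor β > 1 (the exact value is a hyperparameter to tune):

def beta_vae_loss(reconstruction, x, mu, logvar, beta=4.0):
    # Identical to vae_loss, except the KL term is weighted by beta
    recon_loss = F.binary_cross_entropy(reconstruction, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss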

#MachineLearning #DeepLearning #VAE #GenerativeModels #Advanced