Variational Autoencoders (VAEs) for Generative Modeling
Learn how VAEs combine autoencoders with probabilistic modeling to generate new data and learn meaningful latent representations.
Regular autoencoders compress data but can't generate new samples. VAEs fix this by learning a probabilistic latent space you can sample from.
Autoencoder vs VAE
Autoencoder:
Input → Encoder → Latent Code (fixed point) → Decoder → Reconstruction
VAE:
Input → Encoder → Latent Distribution (mean, variance) → Sample → Decoder → Reconstruction
The key difference: VAE learns a distribution, not a fixed point.
How VAE Works
- Encoder outputs mean (μ) and variance (σ²) for each latent dimension
- Sample from this distribution using the "reparameterization trick"
- Decoder reconstructs from the sample
- Loss = Reconstruction Loss + KL Divergence
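Together, these two loss terms form the (negative) evidence lower bound, or ELBO, which is the quantity a VAE minimizes during training. In standard notation:

\mathcal{L}(x) = \mathbb{E}_{q(z \mid x)}\left[-\log p(x \mid z)\right] + D_{\mathrm{KL}}\left(q(z \mid x) \,\|\, \mathcal{N}(0, I)\right)

The first term rewards accurate reconstructions; the second keeps the encoder's distribution q(z|x) close to the standard normal prior, so that sampling from N(0, I) later produces sensible outputs.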
The Reparameterization Trick
Gradients can't flow through a random sampling step directly. The fix: sample ε from N(0,1) and apply a deterministic transform:
z = μ + σ × ε where ε ~ N(0,1)
Now gradients flow through μ and σ!
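A minimal, standalone sketch of the trick (the tensor values here are illustrative placeholders; in a real VAE, μ and log σ² come from the encoder):

import torch

mu = torch.zeros(3, requires_grad=True)      # μ (produced by the encoder in practice)
logvar = torch.zeros(3, requires_grad=True)  # log(σ²)
eps = torch.randn(3)                         # ε ~ N(0,1), sampled outside the graph
z = mu + torch.exp(0.5 * logvar) * eps       # z = μ + σ × ε, a deterministic transform
z.sum().backward()                           # backprop works: randomness is an input, not an operation
print(mu.grad, logvar.grad)                  # both gradients are populated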
VAE Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
class VAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=20):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )
        # Latent space parameters
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.Sigmoid()  # For inputs normalized to [0, 1]
        )

    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)  # log(σ²) for numerical stability
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # σ = exp(0.5 * log(σ²))
        eps = torch.randn_like(std)    # ε ~ N(0,1)
        return mu + std * eps          # z = μ + σ × ε

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstruction = self.decode(z)
        return reconstruction, mu, logvar
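A quick shape check, assuming flattened 28×28 MNIST inputs (matching the default input_dim=784):

model = VAE()
x = torch.rand(8, 784)                      # dummy batch of 8 flattened images in [0, 1]
recon, mu, logvar = model(x)
print(recon.shape, mu.shape, logvar.shape)  # (8, 784), (8, 20), (8, 20)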
VAE Loss Function
def vae_loss(reconstruction, x, mu, logvar):
    # Reconstruction loss (how well we recreate the input)
    recon_loss = F.binary_cross_entropy(reconstruction, x, reduction='sum')
    # KL divergence (how close the latent distribution is to N(0,1))
    # Formula: -0.5 * sum(1 + log(σ²) - μ² - σ²)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss
Why KL divergence? It forces the latent space to be:
- Continuous (similar inputs map to similar latent codes)
- Centered around the origin (so you can sample from N(0,1) to generate)
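For a diagonal Gaussian encoder and a standard normal prior, the KL divergence has the closed form used in vae_loss above, summed over the d latent dimensions:

D_{\mathrm{KL}}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\right) = -\tfrac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)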
Training VAE
def train_vae(model, train_loader, epochs=50):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(epochs):
        total_loss = 0
        for batch_x, _ in train_loader:
            batch_x = batch_x.view(-1, 784)  # Flatten for MNIST
            optimizer.zero_grad()
            recon, mu, logvar = model(batch_x)
            loss = vae_loss(recon, batch_x, mu, logvar)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader.dataset)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
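A minimal driver for this loop, assuming torchvision is installed and MNIST can be downloaded locally (the ./data path is just an example):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

train_data = datasets.MNIST("./data", train=True, download=True,
                            transform=transforms.ToTensor())  # pixels already in [0, 1]
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

model = VAE()
train_vae(model, train_loader, epochs=20)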
Generating New Samples
def generate_samples(model, num_samples=16):
    model.eval()
    with torch.no_grad():
        # Sample from the standard normal prior
        z = torch.randn(num_samples, 20)  # must match latent_dim = 20
        samples = model.decode(z)
    return samples

# Generate and visualize
samples = generate_samples(model)
# samples shape: (16, 784) - can reshape to (16, 28, 28) for MNIST
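One way to inspect the results is torchvision's image grid utility (again assuming torchvision is available):

from torchvision.utils import save_image

grid = samples.view(-1, 1, 28, 28)           # reshape flat samples into image tensors
save_image(grid, "vae_samples.png", nrow=4)  # writes a 4×4 grid of generated digits to disk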
Latent Space Interpolation
def interpolate(model, x1, x2, steps=10):
    model.eval()
    with torch.no_grad():
        # Get latent representations (use the means, ignore the variances)
        mu1, _ = model.encode(x1.unsqueeze(0))
        mu2, _ = model.encode(x2.unsqueeze(0))
        # Interpolate linearly in latent space
        interpolations = []
        for alpha in torch.linspace(0, 1, steps):
            z = (1 - alpha) * mu1 + alpha * mu2
            interpolations.append(model.decode(z))
    return torch.cat(interpolations)
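Usage sketch, assuming x1 and x2 are two flattened 784-dim images (here reusing the MNIST dataset from the training driver above):

x1, _ = train_data[0]
x2, _ = train_data[1]
frames = interpolate(model, x1.view(784), x2.view(784), steps=10)
print(frames.shape)  # (10, 784); reshape to (10, 28, 28) to display the morph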
Convolutional VAE (for Images)
class ConvVAE(nn.Module):
    def __init__(self, latent_dim=32):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),   # 28→14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 14→7
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)
        # Decoder
        self.fc_decode = nn.Linear(latent_dim, 64 * 7 * 7)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),  # 7→14
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),   # 14→28
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        h = self.fc_decode(z)
        h = h.view(-1, 64, 7, 7)  # Un-flatten back to feature maps
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
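A shape check for the convolutional version, assuming single-channel 28×28 inputs:

conv_model = ConvVAE()
x = torch.rand(4, 1, 28, 28)
recon, mu, logvar = conv_model(x)
print(recon.shape, mu.shape)  # (4, 1, 28, 28), (4, 32)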
VAE vs GAN
| Aspect | VAE | GAN |
|---|---|---|
| Training | Stable | Can be unstable |
| Sample quality | Blurrier | Sharper |
| Latent space | Structured, interpretable | Less structured |
| Mode coverage | Good (covers all modes) | May miss modes |
| Use case | When you need latent representations | When you need sharp samples |
Key Takeaway
VAEs learn meaningful latent representations by combining autoencoders with probabilistic modeling. The KL divergence term ensures a smooth, continuous latent space you can sample from. They work well for generating variations, interpolating between samples, and learning disentangled features. Start with a simple VAE, then try β-VAE for better disentanglement, as sketched below.
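β-VAE changes only the loss weighting: the KL term is multiplied by a factor β > 1, trading some reconstruction quality for a more disentangled latent space. A minimal sketch on top of the loss defined earlier (β = 4 is just a common starting value, not a universal setting):

def beta_vae_loss(reconstruction, x, mu, logvar, beta=4.0):
    recon_loss = F.binary_cross_entropy(reconstruction, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss  # beta=1 recovers the standard VAE loss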