Neural Network Optimization: BatchNorm, Dropout & More
Master essential techniques for training better neural networks: batch normalization, dropout, weight initialization, and learning rate scheduling.
Training neural networks is tricky. Without the right techniques, models may fail to converge, overfit badly, or take forever to train. Here are the essential tools.
1. Batch Normalization
Problem: Internal covariate shift - each layer's input distribution changes during training, making optimization unstable.
Solution: Normalize inputs to each layer.
import torch
import torch.nn as nn

class ModelWithBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.bn1 = nn.BatchNorm1d(256)  # Normalize after fc1
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.bn1(torch.relu(self.fc1(x)))
        x = self.bn2(torch.relu(self.fc2(x)))
        return self.fc3(x)
Benefits:
- Faster training (use higher learning rates)
- Reduces sensitivity to initialization
- Acts as regularization
For CNNs: Use BatchNorm2d after conv layers.
self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
self.bn1 = nn.BatchNorm2d(64) # Match number of channels
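To make the normalization concrete, here is a minimal sketch (the tensor values are purely illustrative): in training mode, BatchNorm1d standardizes each feature across the batch (then applies a learnable scale and shift), while also maintaining running statistics that it uses instead of batch statistics in eval mode.
bn = nn.BatchNorm1d(3)
x = torch.randn(32, 3) * 5 + 10       # batch of 32 samples, 3 features, shifted and scaled

bn.train()
y = bn(x)
print(y.mean(dim=0), y.std(dim=0))    # per-feature mean ~0, std ~1
print(bn.running_mean)                # running stats, updated on each training forward pass

bn.eval()
y_eval = bn(x)                        # now uses running_mean/running_var, not batch stats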
2. Dropout
Problem: Overfitting - network memorizes training data.
Solution: Randomly "drop" neurons during training.
class ModelWithDropout(nn.Module):
    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.dropout1(torch.relu(self.fc1(x)))
        x = self.dropout2(torch.relu(self.fc2(x)))
        return self.fc3(x)
Key points:
- Dropout rate 0.5 is common for fully connected layers
- Use 0.2-0.3 for input layer
- Automatically disabled during evaluation (model.eval()) - see the sketch after this list
- Don't use dropout after the final layer
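Because the train/eval behaviour trips people up, here is a small sketch of the difference (the all-ones tensor is just for illustration): PyTorch uses inverted dropout, so in training mode it zeroes entries and scales survivors by 1/(1-p), and in eval mode Dropout is a no-op.
drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()
print(drop(x))   # roughly half the entries zeroed, survivors scaled to 2.0 (= 1 / (1 - p))
drop.eval()
print(drop(x))   # identity: all ones, no scaling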
3. Weight Initialization
Problem: Bad initialization → vanishing/exploding gradients.
Solution: Initialize weights properly for your activation function.
def init_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            # Xavier for tanh/sigmoid:
            # nn.init.xavier_uniform_(m.weight)
            # Kaiming/He for ReLU (recommended):
            nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')

# Apply to the model (the function already walks all submodules)
init_weights(model)
Rule of thumb:
- ReLU activation → Kaiming/He initialization
- Tanh/Sigmoid → Xavier/Glorot initialization
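To see why this matters for the "Problem" above, here is a small self-contained sketch comparing a deliberately tiny Gaussian init against Kaiming/He init across a stack of ReLU layers; the depth, width, and std values are arbitrary choices for illustration.
import torch
import torch.nn as nn

def activation_std(init_fn, depth=10, width=256):
    torch.manual_seed(0)
    layers = [nn.Linear(width, width) for _ in range(depth)]
    for layer in layers:
        init_fn(layer.weight)
        nn.init.zeros_(layer.bias)
    x = torch.randn(1024, width)
    with torch.no_grad():
        for layer in layers:
            x = torch.relu(layer(x))
    return x.std().item()

tiny = activation_std(lambda w: nn.init.normal_(w, std=0.01))
kaiming = activation_std(lambda w: nn.init.kaiming_uniform_(w, nonlinearity='relu'))
print(f"activation std after 10 ReLU layers - tiny init: {tiny:.2e}, Kaiming: {kaiming:.2f}")
# The tiny init shrinks activations toward zero layer by layer (vanishing signal);
# Kaiming keeps their scale roughly constant.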
4. Learning Rate Scheduling
Problem: Fixed learning rate is suboptimal - need high LR early, low LR later.
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ReduceLROnPlateau
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Option 1: Step decay (reduce every N epochs)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
# Option 2: Cosine annealing (smooth decay)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
# Option 3: Reduce on plateau (reduce when metric stops improving)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=10, factor=0.1)
# Training loop
for epoch in range(epochs):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = validate(model, val_loader)

    # For StepLR or CosineAnnealingLR
    scheduler.step()
    # For ReduceLROnPlateau, pass the monitored metric instead:
    # scheduler.step(val_loss)

    print(f"LR: {optimizer.param_groups[0]['lr']}")
5. Gradient Clipping
Problem: Exploding gradients (especially in RNNs).
# Clip by value
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
# Clip by norm (more common)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# In training loop
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
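clip_grad_norm_ also returns the total gradient norm measured before clipping, which makes it easy to watch for exploding gradients inside the loop above; a minimal sketch (the warning threshold of 10.0 is an arbitrary choice):
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if grad_norm > 10.0:
    print(f"Warning: gradient norm {float(grad_norm):.2f} before clipping")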
6. Layer Normalization (for Transformers)
BatchNorm works poorly for sequence models: batch statistics are unreliable with variable-length sequences and small batches. Transformers use LayerNorm instead, which normalizes across the feature dimension of each token independently of the batch:
class TransformerBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, num_heads=8)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Pre-norm architecture: normalize first, then add the residual
        h = self.norm1(x)
        x = x + self.attention(h, h, h)[0]
        x = x + self.ffn(self.norm2(x))
        return x
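A quick usage check: nn.MultiheadAttention defaults to batch_first=False, so inputs are shaped (seq_len, batch, d_model). The concrete sizes below are arbitrary (d_model=512 is divisible by the 8 heads).
block = TransformerBlock(d_model=512)
x = torch.randn(20, 8, 512)   # 20 tokens, batch of 8
out = block(x)
print(out.shape)              # torch.Size([20, 8, 512])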
Complete Training Template
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
# Model with all techniques
class OptimizedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 10)
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')

    def forward(self, x):
        return self.layers(x)

# Training setup
model = OptimizedModel()
criterion = nn.CrossEntropyLoss()  # assuming a classification task
optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(100):
    model.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
    scheduler.step()

    model.eval()
    # Validate...
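The validation step is left open above; a minimal sketch of what it might look like, assuming val_loader yields (inputs, targets) pairs for a classification task:
correct, total = 0, 0
with torch.no_grad():
    for inputs, targets in val_loader:
        preds = model(inputs).argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += targets.size(0)
print(f"Validation accuracy: {correct / total:.3f}")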
Key Takeaway
These techniques aren't optional extras - they're essential for training deep networks effectively. BatchNorm speeds up training, Dropout reduces overfitting, proper initialization avoids vanishing or exploding gradients, and LR scheduling improves convergence. Use them together for best results.