Deep Q-Networks (DQN): Deep Reinforcement Learning
Learn how DQN combines deep learning with Q-learning to solve complex environments like video games.
Q-Learning works great for small problems. But what if you have millions of states? You can't store a Q-table that big. DQN solves this by using a neural network to approximate Q-values.
The Problem with Tabular Q-Learning
Atari Game State Space:
- Screen: 210 x 160 pixels
- Colors: 128 possible values
- States: 128^(210×160), which is more than 10^70,000
Q-Table for this? Impossible.
The DQN Solution
Instead of a table, use a neural network:
Q-Table: state → lookup → Q-values
DQN: state → neural network → Q-values
The network learns to generalize across similar states.
Basic DQN Architecture
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # Q-values for each action

# For image inputs (like Atari)
class ConvDQN(nn.Module):
    def __init__(self, action_size):
        super(ConvDQN, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)   # 4 stacked input frames
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(7 * 7 * 64, 512)  # 7x7x64 assumes 84x84 input frames
        self.fc2 = nn.Linear(512, action_size)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # flatten conv features
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
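As a quick sanity check (a sketch assuming the standard Atari preprocessing of 4 stacked 84x84 grayscale frames and a hypothetical 6-action game), the network outputs one Q-value per action:

# Dummy batch shaped like preprocessed Atari input: (batch, 4 frames, 84, 84)
dummy = torch.zeros(2, 4, 84, 84)
net = ConvDQN(action_size=6)   # 6 actions is just an example
print(net(dummy).shape)        # torch.Size([2, 6])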
Two Key Innovations
1. Experience Replay
Don't learn from each experience immediately. Store transitions in a buffer and sample random minibatches from it:
from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (torch.stack(states), torch.tensor(actions),
                torch.tensor(rewards, dtype=torch.float32),  # force float32 for the loss
                torch.stack(next_states), torch.tensor(dones))

    def __len__(self):
        return len(self.buffer)
Why? Random sampling breaks the correlation between consecutive transitions and lets each experience be reused several times, which makes learning more stable.
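A minimal usage sketch with made-up transitions (assuming 4-dimensional states, as in CartPole):

# Fill the buffer with fake transitions, then sample a training batch
buffer = ReplayBuffer(capacity=1000)
for _ in range(100):
    s, s_next = torch.randn(4), torch.randn(4)
    buffer.push(s, 0, 1.0, s_next, False)

states, actions, rewards, next_states, dones = buffer.sample(32)
print(states.shape)  # torch.Size([32, 4])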
2. Target Network
Use a separate network for target Q-values:
# Main network (updated every step)
policy_net = DQN(state_size, action_size)

# Target network (updated periodically)
target_net = DQN(state_size, action_size)
target_net.load_state_dict(policy_net.state_dict())

# Update target network every N steps
def update_target(policy_net, target_net):
    target_net.load_state_dict(policy_net.state_dict())
Why? The target values change more slowly, so the network isn't "chasing a moving target", which stabilizes training.
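A common variant (not the hard update used in the original DQN paper, but worth knowing) is a soft update that blends the target network toward the policy network a little on every step; a minimal sketch:

# Soft ("Polyak") update: nudge target weights toward policy weights each step.
# tau is a small mixing factor; 0.005 is just an illustrative value.
def soft_update(policy_net, target_net, tau=0.005):
    for t_param, p_param in zip(target_net.parameters(), policy_net.parameters()):
        t_param.data.copy_(tau * p_param.data + (1.0 - tau) * t_param.data)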
Complete Training Loop
import torch.optim as optim

def train_dqn(env, episodes=1000):
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    policy_net = DQN(state_size, action_size)
    target_net = DQN(state_size, action_size)
    target_net.load_state_dict(policy_net.state_dict())

    optimizer = optim.Adam(policy_net.parameters(), lr=0.001)
    memory = ReplayBuffer(10000)

    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.995
    gamma = 0.99
    batch_size = 64
    target_update = 10  # episodes between target network updates

    for episode in range(episodes):
        # Classic gym API: reset() returns just the observation
        state = torch.tensor(env.reset(), dtype=torch.float32)
        total_reward = 0

        while True:
            # Epsilon-greedy action selection
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    action = policy_net(state).argmax().item()

            next_state, reward, done, _ = env.step(action)
            next_state = torch.tensor(next_state, dtype=torch.float32)

            memory.push(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

            # Train when enough samples
            if len(memory) >= batch_size:
                train_step(policy_net, target_net, memory,
                           optimizer, batch_size, gamma)

            if done:
                break

        # Update target network
        if episode % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())

        # Decay epsilon
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        print(f"Episode {episode}, Reward: {total_reward:.0f}, Epsilon: {epsilon:.2f}")

def train_step(policy_net, target_net, memory, optimizer, batch_size, gamma):
    states, actions, rewards, next_states, dones = memory.sample(batch_size)

    # Current Q values for the actions that were actually taken
    current_q = policy_net(states).gather(1, actions.unsqueeze(1))

    # Target Q values from the (frozen) target network
    with torch.no_grad():
        max_next_q = target_net(next_states).max(1)[0]
        target_q = rewards + gamma * max_next_q * (1 - dones.float())

    # Loss and update
    loss = nn.MSELoss()(current_q.squeeze(), target_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
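To try it out on a simple environment (a sketch assuming the classic gym API used above, where reset() returns only the observation and step() returns four values; newer gymnasium releases differ):

import gym

# CartPole-v1: 4-dimensional state, 2 discrete actions - a good first DQN test
env = gym.make("CartPole-v1")
train_dqn(env, episodes=500)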
DQN Improvements
| Variant | Improvement |
|---|---|
| Double DQN | Reduces overestimation of Q-values |
| Dueling DQN | Separates state value and advantage |
| Prioritized Replay | Samples important experiences more often |
| Rainbow | Combines these and other improvements |
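As an example, Double DQN changes only the target computation in train_step: the policy network picks the best next action and the target network evaluates it; a minimal sketch:

# Double DQN target: policy_net selects the action, target_net evaluates it.
# Drop-in replacement for the target computation inside train_step above.
with torch.no_grad():
    best_actions = policy_net(next_states).argmax(1, keepdim=True)
    max_next_q = target_net(next_states).gather(1, best_actions).squeeze(1)
    target_q = rewards + gamma * max_next_q * (1 - dones.float())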
Key Takeaway
DQN enables reinforcement learning on complex, high-dimensional problems by using neural networks to approximate Q-values. The two key tricks - experience replay and target networks - make training stable. Start with basic DQN, then try Double DQN for better performance.