ML · 12 min read

Deep Q-Networks (DQN): Deep Reinforcement Learning

Learn how DQN combines deep learning with Q-learning to solve complex environments like video games.

Sarah Chen
December 19, 2025

Q-Learning works great for small problems. But what if you have millions of states? You can't store a Q-table that big. DQN solves this by using a neural network to approximate Q-values.

The Problem with Tabular Q-Learning

Atari Game State Space:
- Screen: 210 x 160 pixels
- Colors: 128 possible values
- States: 128^(210×160) = 128^33,600, a number with roughly 70,800 decimal digits

Q-Table for this? Impossible.
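
If you want to sanity-check that claim, a quick back-of-the-envelope calculation in plain Python (nothing DQN-specific) gives the digit count:

import math

# 210 x 160 pixels, each taking one of 128 colors
num_pixels = 210 * 160
digits = num_pixels * math.log10(128)
print(f"128^{num_pixels} has about {digits:,.0f} decimal digits")
# ≈ 70,802 digits -- far beyond anything a Q-table could ever store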

The DQN Solution

Instead of a table, use a neural network:

Q-Table: state → lookup → Q-values
DQN:     state → neural network → Q-values

The network learns to generalize across similar states.

Basic DQN Architecture

import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_size)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # Q-values for each action

# For image inputs (like Atari). Expects 4 stacked 84x84 grayscale frames
# (the standard Atari preprocessing), which is where the 7*7*64 flatten size comes from.
class ConvDQN(nn.Module):
    def __init__(self, action_size):
        super(ConvDQN, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(7 * 7 * 64, 512)
        self.fc2 = nn.Linear(512, action_size)
    
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

Two Key Innovations

1. Experience Replay

Don't learn from experiences immediately. Store them and sample randomly:

from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # Explicit dtypes keep later arithmetic (gather, TD targets) well-behaved
        return (torch.stack(states),
                torch.tensor(actions, dtype=torch.int64),
                torch.tensor(rewards, dtype=torch.float32),
                torch.stack(next_states),
                torch.tensor(dones, dtype=torch.float32))
    
    def __len__(self):
        return len(self.buffer)

Why? Breaks correlation between consecutive samples. More stable learning.
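
A tiny sketch (random stand-in states instead of a real environment) makes the point concrete: transitions go in consecutively, but a sampled batch is drawn from all over the buffer:

buffer = ReplayBuffer(capacity=1000)

for t in range(200):                                 # 200 consecutive, highly correlated steps
    s, s_next = torch.randn(4), torch.randn(4)       # stand-ins for real observations
    buffer.push(s, 0, 1.0, s_next, False)

states, actions, rewards, next_states, dones = buffer.sample(32)
print(states.shape)  # torch.Size([32, 4]) -- a random mix, not the 32 most recent steps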

2. Target Network

Use a separate network for target Q-values:

# Main network (updated every step)
policy_net = DQN(state_size, action_size)

# Target network (updated periodically)
target_net = DQN(state_size, action_size)
target_net.load_state_dict(policy_net.state_dict())

# Update target network every N steps
def update_target(policy_net, target_net):
    target_net.load_state_dict(policy_net.state_dict())

Why? Prevents "chasing a moving target" - stabilizes training.

Complete Training Loop

import torch.optim as optim

def train_dqn(env, episodes=1000):
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    
    policy_net = DQN(state_size, action_size)
    target_net = DQN(state_size, action_size)
    target_net.load_state_dict(policy_net.state_dict())
    
    optimizer = optim.Adam(policy_net.parameters(), lr=0.001)
    memory = ReplayBuffer(10000)
    
    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.995
    gamma = 0.99
    batch_size = 64
    target_update = 10
    
    for episode in range(episodes):
        # Classic Gym API: reset() returns the observation directly
        # (in Gymnasium, reset returns (obs, info) and step returns 5 values)
        state = torch.tensor(env.reset(), dtype=torch.float32)
        total_reward = 0
        
        while True:
            # Epsilon-greedy action selection
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    action = policy_net(state).argmax().item()
            
            next_state, reward, done, _ = env.step(action)
            next_state = torch.tensor(next_state, dtype=torch.float32)
            
            memory.push(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward
            
            # Train when enough samples
            if len(memory) >= batch_size:
                train_step(policy_net, target_net, memory, 
                          optimizer, batch_size, gamma)
            
            if done:
                break
        
        # Update target network
        if episode % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())
        
        # Decay epsilon
        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        
        print(f"Episode {episode}, Reward: {total_reward:.0f}, Epsilon: {epsilon:.2f}")

def train_step(policy_net, target_net, memory, optimizer, batch_size, gamma):
    states, actions, rewards, next_states, dones = memory.sample(batch_size)
    
    # Current Q values
    current_q = policy_net(states).gather(1, actions.unsqueeze(1))
    
    # Target Q values
    with torch.no_grad():
        max_next_q = target_net(next_states).max(1)[0]
        target_q = rewards + gamma * max_next_q * (1 - dones.float())
    
    # Loss and update
    loss = nn.MSELoss()(current_q.squeeze(), target_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
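
To actually run this, here's a minimal sketch assuming the classic Gym API that the loop above uses (gym < 0.26, where step returns 4 values); the environment name and episode count are just illustrative:

import gym

env = gym.make("CartPole-v1")   # 4-dimensional state, 2 discrete actions
train_dqn(env, episodes=500)    # CartPole-v1 counts as solved at an average reward of 475
env.close()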

DQN Improvements

Variant              Improvement
-------------------  -----------------------------------------
Double DQN           Reduces overestimation of Q-values
Dueling DQN          Separates state value and advantage
Prioritized Replay   Samples important experiences more often
Rainbow              Combines all of these improvements
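
Of these, Double DQN is the smallest change to the code above: the policy network chooses the next action and the target network evaluates it, which curbs the overestimation that plain max-based targets suffer from. A minimal sketch of how the target computation inside train_step would change (same variable names as above):

# Double DQN target: select the next action with policy_net, evaluate it with target_net
with torch.no_grad():
    next_actions = policy_net(next_states).argmax(1, keepdim=True)
    max_next_q = target_net(next_states).gather(1, next_actions).squeeze(1)
    target_q = rewards + gamma * max_next_q * (1 - dones.float())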

Key Takeaway

DQN enables reinforcement learning on complex, high-dimensional problems by using neural networks to approximate Q-values. The two key tricks - experience replay and target networks - make training stable. Start with basic DQN, then try Double DQN for better performance.

#Machine Learning#Deep Learning#Reinforcement Learning#DQN#Advanced