Graph Neural Networks (GNNs): Learning on Graph Data

Learn how GNNs process graph-structured data like social networks, molecules, and knowledge graphs through message passing.

Sarah Chen
December 19, 2025
CNNs work on grids (images). RNNs work on sequences. But what about data with irregular connections? Social networks, molecules, recommendation systems - these are graphs. GNNs are designed for them.

What is Graph Data?

Graph = Nodes + Edges

Social Network:
- Nodes: People
- Edges: Friendships

Molecule:
- Nodes: Atoms
- Edges: Chemical bonds

Knowledge Graph:
- Nodes: Entities
- Edges: Relationships
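
To make the nodes-and-edges idea concrete, here is a minimal sketch in plain Python. The names `friendships` and `build_adjacency` are made up for this illustration, and the toy network treats every friendship as undirected.

```python
from collections import defaultdict

# Hypothetical toy social network: each pair is one friendship (an edge).
friendships = [("Ana", "Ben"), ("Ben", "Cy"), ("Cy", "Dee"), ("Ben", "Dee")]

def build_adjacency(edges):
    """Map every node to the set of nodes it shares an edge with."""
    adjacency = defaultdict(set)
    for a, b in edges:
        adjacency[a].add(b)  # friendship is mutual, so store both directions
        adjacency[b].add(a)
    return dict(adjacency)

adjacency = build_adjacency(friendships)
# Degree = number of neighbors; it differs per node, unlike pixels in a grid.
degrees = {person: len(friends) for person, friends in adjacency.items()}
print(degrees)
```

Note how each node has a different number of neighbors. That irregularity is what separates graphs from images or sequences.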

Why Can't We Use Regular Neural Networks?

Problems:

  1. Graphs have variable size and structure
  2. No fixed ordering of nodes
  3. Need to capture connectivity patterns

Solution: Message Passing - nodes exchange information with neighbors.
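
The "no fixed ordering" problem can be shown in a few lines. The sketch below is an illustration under simple assumptions: a fixed weight vector stands in for one unit of a dense layer, and the feature values are arbitrary.

```python
import torch

# Three neighbor feature vectors, and the same vectors in a different order.
neighbors = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
shuffled = neighbors[torch.tensor([2, 0, 1])]

# Sum aggregation ignores ordering: both orders give the same vector.
sum_matches = torch.equal(neighbors.sum(dim=0), shuffled.sum(dim=0))

# Flattening the nodes and applying fixed weights (a stand-in for one
# dense-layer unit) depends on which node happens to come first.
weights = torch.arange(6.0)
flat_matches = torch.equal(neighbors.flatten() @ weights,
                           shuffled.flatten() @ weights)

print(sum_matches, flat_matches)
```

An order-independent aggregation such as sum, mean, or max is what lets message passing treat a node's neighbors as a set rather than a sequence.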

The Message Passing Framework

For each node:
1. AGGREGATE: Collect messages from neighbors
2. UPDATE: Combine with own features
3. Repeat for multiple layers
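
# Runnable sketch of one message passing layer (mean aggregation, no learned weights)
import torch

def message_passing_layer(node_features, edge_index):
    src, dst = edge_index  # row 0: source nodes, row 1: target nodes
    new_features = []
    for node in range(node_features.size(0)):
        # Aggregate: mean of incoming neighbor features (sum or max also work)
        neighbor_features = node_features[src[dst == node]]
        if neighbor_features.size(0) > 0:
            aggregated = neighbor_features.mean(dim=0)
        else:
            aggregated = torch.zeros_like(node_features[node])

        # Update: a plain sum stands in here for a learned update function
        new_features.append(node_features[node] + aggregated)

    return torch.stack(new_features)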
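
The per-node loop above is easy to read but slow. As a rough sketch of how the same aggregate step can be vectorized in plain PyTorch, the hypothetical helper `sum_aggregate` below uses `index_add_` to sum neighbor features for all nodes at once. Sum aggregation and the "add own features" update are assumptions chosen for simplicity.

```python
import torch

def sum_aggregate(x, edge_index):
    """Sum the features of each node's incoming neighbors."""
    src, dst = edge_index                 # row 0: source nodes, row 1: target nodes
    out = torch.zeros_like(x)
    out.index_add_(0, dst, x[src])        # out[dst[k]] += x[src[k]] for every edge k
    return out

# Path graph 0 - 1 - 2, each undirected edge stored in both directions.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
x = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

aggregated = sum_aggregate(x, edge_index)
updated = x + aggregated                  # simplest possible update: add own features
print(aggregated)
print(updated)
```

Node 1 has two neighbors, so it receives the sum of both feature vectors, while nodes 0 and 2 each receive only node 1's features.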

GNN with PyTorch Geometric

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data

# Create a simple graph
edge_index = torch.tensor([
    [0, 1, 1, 2],  # Source nodes
    [1, 0, 2, 1]   # Target nodes
], dtype=torch.long)

x = torch.tensor([
    [1, 0],  # Node 0 features
    [0, 1],  # Node 1 features
    [1, 1]   # Node 2 features
], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
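
Notice that `edge_index` above lists every connection twice, once per direction, which is how an undirected graph is stored in this format. The helper below is a hypothetical plain-PyTorch stand-in that shows the idea; PyTorch Geometric ships its own utilities for this.

```python
import torch

def undirected_edge_index(pairs):
    """Build a [2, 2 * len(pairs)] tensor holding each pair in both directions."""
    forward = torch.tensor(pairs, dtype=torch.long).t()   # shape [2, len(pairs)]
    backward = forward.flip(0)                             # swap source and target rows
    return torch.cat([forward, backward], dim=1)

# Same path graph as above: 0 - 1 - 2
edge_index = undirected_edge_index([(0, 1), (1, 2)])
print(edge_index)
```

The columns come out in a different order than in the hand-written tensor, but the set of directed edges is the same, and edge order does not affect message passing.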

Graph Convolutional Network (GCN)

from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
    
    def forward(self, x, edge_index):
        # First message passing layer
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        
        # Second message passing layer
        x = self.conv2(x, edge_index)
        return x

# For node classification
model = GCN(in_channels=2, hidden_channels=16, out_channels=3)
out = model(data.x, data.edge_index)  # Shape: [num_nodes, 3]
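
To see what a single GCN layer computes, here is a rough dense-matrix sketch of the propagation rule from the GCN formulation: add self-loops, normalize symmetrically by node degree, then apply a weight matrix. It assumes a small undirected graph; `gcn_propagate` and the identity weight are illustrative only, and library implementations use sparse operations instead.

```python
import torch

def gcn_propagate(x, edge_index, weight):
    """Compute D^{-1/2} (A + I) D^{-1/2} X W with dense matrices."""
    num_nodes = x.size(0)
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[1], edge_index[0]] = 1.0      # adj[target, source] = 1
    adj = adj + torch.eye(num_nodes)             # self-loops keep each node's own features
    deg_inv_sqrt = adj.sum(dim=1).pow(-0.5)      # degree counts the self-loop too
    norm_adj = deg_inv_sqrt.unsqueeze(1) * adj * deg_inv_sqrt.unsqueeze(0)
    return norm_adj @ x @ weight

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
x = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
weight = torch.eye(2)                            # identity, so only the smoothing is visible

out = gcn_propagate(x, edge_index, weight)
print(out)
```

With an identity weight, each output row is a degree-weighted blend of the node's own features and its neighbors' features. In a real layer, the learned weight matrix is what gets trained.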

GraphSAGE: Scalable GNNs

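The original GCN trains on the full graph at once, which becomes expensive as the graph grows. GraphSAGE aggregates over a sampled, fixed-size set of neighbors, so it can train in mini-batches. Note that SAGEConv only defines the aggregation; in PyTorch Geometric the sampling itself is done by a data loader such as NeighborLoader: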

from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
    
    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
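
The sampling idea can be sketched without any graph library. The code below is a simplified illustration, not the PyTorch Geometric implementation: `sample_neighbors` and `sage_mean_layer` are hypothetical helpers, the weights are random, and the mean aggregator is one of several possible choices.

```python
import random
import torch

def sample_neighbors(edge_index, node, num_samples, rng):
    """Return at most `num_samples` incoming neighbors of `node`."""
    src, dst = edge_index
    neighbors = src[dst == node].tolist()
    if len(neighbors) <= num_samples:
        return neighbors
    return rng.sample(neighbors, num_samples)

def sage_mean_layer(x, edge_index, w_self, w_neigh, num_samples, rng):
    """One GraphSAGE-style layer: own features plus mean of sampled neighbors."""
    rows = []
    for node in range(x.size(0)):
        sampled = sample_neighbors(edge_index, node, num_samples, rng)
        if sampled:
            neigh_mean = x[sampled].mean(dim=0)
        else:
            neigh_mean = torch.zeros(x.size(1))   # isolated node: nothing to aggregate
        rows.append(x[node] @ w_self + neigh_mean @ w_neigh)
    return torch.relu(torch.stack(rows))

rng = random.Random(0)
torch.manual_seed(0)
# Star graph: nodes 1-4 all send messages to node 0.
edge_index = torch.tensor([[1, 2, 3, 4],
                           [0, 0, 0, 0]])
x = torch.randn(5, 3)
w_self, w_neigh = torch.randn(3, 8), torch.randn(3, 8)
out = sage_mean_layer(x, edge_index, w_self, w_neigh, num_samples=2, rng=rng)
print(out.shape)
```

Node 0 has four neighbors but only two are used per pass, so the cost per node is bounded by `num_samples` regardless of how large the neighborhood is.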

Graph Attention Networks (GAT)

Not all neighbors are equally important. GAT learns an attention weight for each neighbor, so a node can weight informative neighbors more heavily when aggregating:

from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)
    
    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
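
For intuition about where the attention weights come from, here is a single-head sketch in plain PyTorch that follows the scoring rule from the GAT formulation: project the features, score each (node, neighbor) pair with a LeakyReLU over a learned vector, then softmax over the neighbors. The function name `attention_weights`, the dimensions, and the random parameters are all assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def attention_weights(h, weight, att, node, neighbors):
    """Return one softmax-normalized weight per neighbor of `node`."""
    z = h @ weight                                          # project: [num_nodes, out_dim]
    z_i = z[node].unsqueeze(0).expand(len(neighbors), -1)   # center node, repeated
    z_j = z[neighbors]                                      # neighbor projections
    scores = F.leaky_relu(torch.cat([z_i, z_j], dim=1) @ att, negative_slope=0.2)
    return torch.softmax(scores, dim=0)                     # normalize over neighbors

torch.manual_seed(0)
h = torch.randn(4, 3)          # 4 nodes, 3 input features
weight = torch.randn(3, 5)     # projection to out_dim = 5
att = torch.randn(10)          # attention vector of size 2 * out_dim
alpha = attention_weights(h, weight, att, node=0, neighbors=[1, 2, 3])
print(alpha, alpha.sum())
```

The weights sum to 1 across the neighbors, so the aggregated message is a learned weighted average rather than a plain mean. Multi-head attention repeats this with independent parameters per head.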

Graph-Level Tasks (Graph Classification)

Sometimes you need to classify entire graphs, not nodes:

from torch_geometric.nn import GCNConv, global_mean_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, num_classes)
    
    def forward(self, x, edge_index, batch):
        # Node embeddings
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        
        # Aggregate to graph-level
        x = global_mean_pool(x, batch)  # Average all nodes per graph
        
        # Classify
        return self.classifier(x)
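
The `batch` argument is a vector that assigns each node to a graph, which is how several graphs are processed together in one pass. The sketch below reimplements mean pooling in plain PyTorch to show what that step does; `mean_pool` is a hypothetical helper, not the library function.

```python
import torch

def mean_pool(x, batch):
    """Average node features per graph; batch[i] is the graph id of node i."""
    num_graphs = int(batch.max()) + 1
    sums = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, x)
    counts = torch.bincount(batch, minlength=num_graphs).unsqueeze(1)
    return sums / counts

# Five nodes: the first three belong to graph 0, the last two to graph 1.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],
                  [10.0, 20.0], [30.0, 40.0]])
batch = torch.tensor([0, 0, 0, 1, 1])

pooled = mean_pool(x, batch)
print(pooled)
```

Each graph ends up with one fixed-size vector no matter how many nodes it has, which is what lets a regular linear classifier sit on top.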

Link Prediction

Predict if an edge should exist between two nodes:

def link_prediction(model, data, node_i, node_j):
    # Get node embeddings
    embeddings = model(data.x, data.edge_index)
    
    # Score the potential edge
    emb_i = embeddings[node_i]
    emb_j = embeddings[node_j]
    
    # Simple: dot product
    score = torch.dot(emb_i, emb_j)
    probability = torch.sigmoid(score)
    
    return probability
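
The function above scores a single pair. To train such a scorer, one common approach is to contrast existing edges with sampled non-edges. The sketch below assumes that setup: `link_prediction_loss` is a hypothetical helper, the embeddings are random stand-ins for a model's output, and the negative edges are hand-picked rather than sampled.

```python
import torch
import torch.nn.functional as F

def link_prediction_loss(embeddings, pos_edge_index, neg_edge_index):
    """Binary cross-entropy over dot-product scores of real and fake edges."""
    def dot_scores(edge_index):
        src, dst = edge_index
        return (embeddings[src] * embeddings[dst]).sum(dim=1)

    logits = torch.cat([dot_scores(pos_edge_index), dot_scores(neg_edge_index)])
    labels = torch.cat([torch.ones(pos_edge_index.size(1)),    # real edges -> 1
                        torch.zeros(neg_edge_index.size(1))])  # non-edges -> 0
    return F.binary_cross_entropy_with_logits(logits, labels)

torch.manual_seed(0)
embeddings = torch.randn(4, 8, requires_grad=True)   # stand-in for model output
pos_edge_index = torch.tensor([[0, 1], [1, 2]])      # edges 0->1 and 1->2 exist
neg_edge_index = torch.tensor([[0, 3], [3, 1]])      # pairs assumed not connected

loss = link_prediction_loss(embeddings, pos_edge_index, neg_edge_index)
loss.backward()                                      # gradients flow to the embeddings
print(loss.item())
```

Minimizing this loss pushes embeddings of connected nodes toward high dot products and unconnected pairs toward low ones.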

Training a GNN

from torch_geometric.datasets import Planetoid

# Load Cora dataset (citation network)
dataset = Planetoid(root='data', name='Cora')
data = dataset[0]

model = GCN(
    in_channels=dataset.num_features,
    hidden_channels=16,
    out_channels=dataset.num_classes
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()  # evaluation needs no gradients
def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    correct = pred[data.test_mask] == data.y[data.test_mask]
    return correct.sum().item() / data.test_mask.sum().item()

for epoch in range(200):
    loss = train()
    if epoch % 20 == 0:
        acc = test()
        print(f"Epoch {epoch}, Loss: {loss:.4f}, Test Acc: {acc:.4f}")
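
The loop above reports accuracy on `data.test_mask` while training. A common alternative is to monitor the validation split during training (Planetoid datasets such as Cora provide `data.val_mask`) and evaluate on the test split only once at the end. The helper below is a hypothetical sketch of a mask-based accuracy function, shown on toy tensors so it runs without downloading a dataset.

```python
import torch

@torch.no_grad()
def masked_accuracy(logits, labels, mask):
    """Fraction of correct predictions among the nodes selected by `mask`."""
    pred = logits[mask].argmax(dim=1)
    return (pred == labels[mask]).float().mean().item()

# Toy stand-ins for model output, node labels, and a validation mask.
logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [0.9, 0.3], [0.1, 0.8]])
labels = torch.tensor([0, 1, 1, 1])
val_mask = torch.tensor([True, True, True, False])

val_acc = masked_accuracy(logits, labels, val_mask)
print(f"Val Acc: {val_acc:.4f}")
```

In the training loop, the same function could be called with `data.val_mask` every few epochs and with `data.test_mask` once training has finished.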

GNN Applications

Application      | Node Task             | Edge Task             | Graph Task
-----------------|-----------------------|-----------------------|--------------------
Social Networks  | User classification   | Friend recommendation | Community detection
Molecules        | Atom property         | Bond prediction       | Drug discovery
Knowledge Graphs | Entity classification | Link prediction       | -
Traffic          | Road segment          | -                     | Traffic prediction

Key Takeaway

GNNs extend deep learning to graph-structured data through message passing: each node aggregates information from its neighbors and learns a representation that reflects the surrounding graph structure. Start with GCN for simplicity, use GraphSAGE with neighbor sampling for large graphs, and use GAT when neighbors differ in importance. PyTorch Geometric provides ready-made layers for all three.
