Graph Neural Networks (GNNs): Learning on Graph Data
Learn how GNNs process graph-structured data like social networks, molecules, and knowledge graphs through message passing.
Graph Neural Networks (GNNs): Learning on Graph Data
CNNs work on regular grids (images), and RNNs work on sequences. But what about data with irregular connections? Social networks, molecules, and recommendation systems are all naturally represented as graphs, and GNNs are designed for exactly this kind of data.
What is Graph Data?
Graph = Nodes + Edges
Social Network:
- Nodes: People
- Edges: Friendships
Molecule:
- Nodes: Atoms
- Edges: Chemical bonds
Knowledge Graph:
- Nodes: Entities
- Edges: Relationships
Why Can't We Use Regular Neural Networks?
Problems:
- Graphs have variable size and structure
- No fixed ordering of nodes
- Need to capture connectivity patterns
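Solution: message passing — each node exchanges information with its neighbors. Neighbor information is combined with order-independent operations (sum, mean, max). As a result, the output does not depend on how the nodes are numbered, and the same layer works on graphs of any size.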
The Message Passing Framework
For each node:
1. AGGREGATE: Collect messages from neighbors
2. UPDATE: Combine with own features
3. Repeat for multiple layers: after k layers, each node's representation reflects information from its k-hop neighborhood
# Conceptual pseudocode: get_neighbors, aggregate and update are placeholders
# for the neighbor lookup, the reduction (sum/mean/max) and the learned update.
def message_passing_layer(node_features, edge_index):
    """Run one round of message passing over every node.

    Args:
        node_features: Indexable collection with one feature entry per node.
        edge_index: Graph connectivity, used only to look up each node's neighbors.

    Returns:
        A list holding one updated feature per node, in node order. Every node
        reads from the input ``node_features``, never from features already
        updated earlier in the same round.
    """
    updated_features = []
    for node in range(len(node_features)):
        # 1. AGGREGATE: look up neighbor indices, gather their features, reduce them
        neighbors = get_neighbors(node, edge_index)
        neighbor_features = [node_features[n] for n in neighbors]
        aggregated = aggregate(neighbor_features)  # sum, mean, max
        # 2. UPDATE: combine the aggregate with the node's own features
        updated_features.append(update(node_features[node], aggregated))
    return updated_features
GNN with PyTorch Geometric
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data
# Build a tiny path graph 0 - 1 - 2. Each undirected connection is stored
# as two directed (source, target) edges, one per direction.
edge_index = torch.tensor(
    [(0, 1), (1, 0), (1, 2), (2, 1)],  # (source, target) pairs
    dtype=torch.long,
).t().contiguous()  # -> shape [2, num_edges]: row 0 = sources, row 1 = targets

# Node feature matrix: row i holds the 2 features of node i.
x = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
Graph Convolutional Network (GCN)
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Size of the hidden node embeddings.
        out_channels: Size of each output vector (e.g., number of classes).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply GCNConv -> ReLU -> dropout -> GCNConv and return raw scores."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout only fires in training mode (tied to self.training).
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.conv2(hidden, edge_index)
# For node classification: 2 input features, 16 hidden units, 3 output classes
model = GCN(in_channels=2, hidden_channels=16, out_channels=3)
# A freshly constructed module is in training mode, where the F.dropout call in
# GCN.forward is active. Switch to eval mode, and skip autograd bookkeeping,
# for a deterministic inference pass.
model.eval()
with torch.no_grad():
    out = model(data.x, data.edge_index)  # Shape: [num_nodes, 3]
GraphSAGE: Scalable GNNs
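The GCN above is trained full-batch on the whole graph at once. GraphSAGE learns an aggregation function over a node's neighbors, so it can instead be trained on sampled neighborhoods, which makes it scalable to large graphs. Note that `SAGEConv` itself does not sample; the sampling is done by a mini-batch loader such as PyG's `NeighborLoader`: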
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE model built from SAGEConv layers.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Size of the hidden node embeddings.
        out_channels: Size of each output vector.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply SAGEConv -> ReLU -> SAGEConv (this variant uses no dropout)."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        out = self.conv2(hidden, edge_index)
        return out
Graph Attention Networks (GAT)
Not all neighbors are equally important. GAT learns attention weights:
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
    """Two-layer graph attention network.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Per-head embedding size of the first layer.
        out_channels: Size of each output vector.
        heads: Number of attention heads in the first layer. The second layer
            expects ``hidden_channels * heads`` input features, i.e. the
            first layer's heads are concatenated.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        concat_width = hidden_channels * heads
        self.conv2 = GATConv(concat_width, out_channels, heads=1)

    def forward(self, x, edge_index):
        """Apply GATConv -> ELU -> dropout(0.6, training only) -> GATConv."""
        h = self.conv1(x, edge_index)
        h = F.elu(h)
        h = F.dropout(h, p=0.6, training=self.training)
        return self.conv2(h, edge_index)
Graph-Level Tasks (Graph Classification)
Sometimes you need to classify entire graphs, not nodes:
from torch_geometric.nn import global_mean_pool, global_max_pool
class GraphClassifier(torch.nn.Module):
    """Graph-level classifier: two GCN layers, mean pooling, then a linear head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Size of node embeddings and of the pooled graph vector.
        num_classes: Number of class scores produced per graph.
    """

    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Return class scores for each graph in the batch.

        ``batch`` tells global_mean_pool which graph each node belongs to
        (presumably PyG's per-node graph-index vector; confirm with the loader).
        """
        # Node embeddings: each conv layer is followed by a ReLU.
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x, edge_index))
        # Average all node embeddings of each graph into one vector per graph.
        graph_embedding = global_mean_pool(x, batch)
        return self.classifier(graph_embedding)
Link Prediction
Predict if an edge should exist between two nodes:
def link_prediction(model, data, node_i, node_j):
    """Score how likely an edge is between node pairs using a dot-product decoder.

    Args:
        model: Callable invoked as ``model(data.x, data.edge_index)`` that
            returns one embedding row per node.
        data: Graph object exposing ``x`` and ``edge_index``.
        node_i: Source endpoint index. A plain int scores one pair; a 1-D index
            tensor scores many pairs at once.
        node_j: Target endpoint index, with the same form as ``node_i``. Pair k
            is ``(node_i[k], node_j[k])``.

    Returns:
        Sigmoid of the embedding dot product: a 0-dim tensor for a single
        pair, or a 1-D tensor with one probability per pair.

    Note:
        The model's train/eval mode is left untouched. A model whose forward
        pass uses dropout tied to ``self.training`` gives stochastic scores in
        training mode, so call ``model.eval()`` first for deterministic output.
    """
    # Get node embeddings
    embeddings = model(data.x, data.edge_index)
    emb_i = embeddings[node_i]
    emb_j = embeddings[node_j]
    # Row-wise dot product. For a single pair this equals torch.dot(emb_i, emb_j);
    # unlike torch.dot (1-D inputs only), it also handles batches of pairs.
    score = (emb_i * emb_j).sum(dim=-1)
    return torch.sigmoid(score)
Training a GNN
from torch_geometric.datasets import Planetoid

# Cora citation network, stored under the ./data root directory.
dataset = Planetoid(root='data', name='Cora')
data = dataset[0]  # the graph used by train() and test() below

# Input and output widths come from the dataset; 16 hidden units.
# Positional arguments: (in_channels, hidden_channels, out_channels).
model = GCN(dataset.num_features, 16, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train():
    """Run one full-batch optimization step on the training nodes.

    Returns:
        The training cross-entropy loss as a Python float.
    """
    model.train()  # enables dropout
    optimizer.zero_grad()
    logits = model(data.x, data.edge_index)
    mask = data.train_mask  # loss is computed only on labeled training nodes
    loss = F.cross_entropy(logits[mask], data.y[mask])
    loss.backward()
    optimizer.step()
    return loss.item()
@torch.no_grad()  # evaluation needs no gradients; avoid building an autograd graph
def test(mask=None):
    """Return node-classification accuracy on a boolean node mask.

    Args:
        mask: Nodes to evaluate on. Defaults to ``data.test_mask``. Pass a
            validation mask (if the dataset provides one) to monitor training
            without repeatedly peeking at the test split.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    model.eval()  # disables dropout for deterministic predictions
    if mask is None:
        mask = data.test_mask
    pred = model(data.x, data.edge_index).argmax(dim=1)
    correct = (pred[mask] == data.y[mask]).sum()
    return (correct / mask.sum()).item()
num_epochs = 200
for epoch in range(num_epochs):
    loss = train()
    # Log every 20 epochs and always after the last one. The last epoch (199)
    # is not a multiple of 20, so without the extra check the fully trained
    # model's accuracy would never be reported.
    if epoch % 20 == 0 or epoch == num_epochs - 1:
        acc = test()
        print(f"Epoch {epoch}, Loss: {loss:.4f}, Test Acc: {acc:.4f}")
GNN Applications
| Application | Node Task | Edge Task | Graph Task |
|---|---|---|---|
| Social Networks | User classification, community detection (node clustering) | Friend recommendation | Classifying whole communities or discussion threads |
| Molecules | Atom property | Bond prediction | Molecular property prediction (e.g., for drug discovery) |
| Knowledge Graphs | Entity classification | Link prediction | - |
| Traffic | Traffic forecasting per road segment | - | - |
Key Takeaway
GNNs extend deep learning to graph-structured data through message passing. Nodes aggregate information from their neighbors, learning representations that capture graph structure. Start with GCN for simplicity. Use GraphSAGE (with neighbor sampling) for large graphs, and GAT when neighbors should contribute unequally through learned attention weights. PyTorch Geometric makes implementation straightforward.