Graph Neural Networks (GNNs): Learning on Graph Data
Learn how GNNs process graph-structured data like social networks, molecules, and knowledge graphs through message passing.
CNNs work on grids (images). RNNs work on sequences. But what about data with irregular connections? Social networks, molecules, and recommendation systems are all naturally represented as graphs, and GNNs are designed for them.
What is Graph Data?
```
Graph = Nodes + Edges

Social Network:
- Nodes: People
- Edges: Friendships

Molecule:
- Nodes: Atoms
- Edges: Chemical bonds

Knowledge Graph:
- Nodes: Entities
- Edges: Relationships
```
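Here is a concrete version of "nodes + edges" in plain Python. This is a minimal sketch. The four-person friendship network is made-up example data, and `build_adjacency` is a hypothetical helper. It stores the graph as an edge list and builds an adjacency list, which is the per-node neighbor lookup that message passing relies on.

```python
# Hypothetical toy social network: friendships are undirected edges
people = ["Ana", "Ben", "Cai", "Dee"]
friendships = [("Ana", "Ben"), ("Ben", "Cai"), ("Cai", "Ana"), ("Cai", "Dee")]

def build_adjacency(nodes, edges):
    """Return {node: sorted list of neighbors} for an undirected graph."""
    adjacency = {node: set() for node in nodes}  # isolated nodes still get an entry
    for u, v in edges:
        adjacency[u].add(v)  # an undirected edge is visible from both endpoints
        adjacency[v].add(u)
    return {node: sorted(neighbors) for node, neighbors in adjacency.items()}

adjacency = build_adjacency(people, friendships)
for person, friends in adjacency.items():
    print(person, "->", friends)
```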
Why Can't We Use Regular Neural Networks?
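**Problems:**
1. Graphs vary in size and structure, so there is no fixed-size input.
2. Nodes have no natural ordering, so a model's output should not depend on how the nodes happen to be numbered.
3. The model must capture connectivity patterns, not just per-node features.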
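Point 2 is why the choice of aggregation matters. The NumPy sketch below uses random features as assumed stand-ins for real node features. If you flatten the node list and feed it into a dense layer, the result changes when the same nodes are listed in a different order. Sum and mean over the nodes give the same result either way.

```python
import numpy as np

rng = np.random.default_rng(0)
features = rng.normal(size=(4, 3))    # 4 nodes with 3 features each (random stand-ins)
order = np.array([2, 0, 3, 1])        # the same nodes, numbered differently
shuffled = features[order]

# A dense layer on the flattened node list: the result depends on node order
weights = rng.normal(size=(4 * 3, 2))
dense_same = np.allclose(features.reshape(-1) @ weights, shuffled.reshape(-1) @ weights)

# Sum and mean over nodes ignore the order entirely
sum_same = np.allclose(features.sum(axis=0), shuffled.sum(axis=0))
mean_same = np.allclose(features.mean(axis=0), shuffled.mean(axis=0))

print("dense layer unchanged:", dense_same)
print("sum unchanged:", sum_same, "| mean unchanged:", mean_same)
```

This is why GNN aggregators are built from order-independent operations such as sum, mean, and max.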
**Solution:** Message passing, in which each node updates its representation by exchanging information with its neighbors.
The Message Passing Framework
```
For each node:
1. AGGREGATE: Collect messages from neighbors
2. UPDATE: Combine with own features
3. Repeat for multiple layers
```
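One layer of this framework is commonly written as an equation. The notation below is a generic convention, not taken from a specific source. Node v's representation at layer k+1 is computed from two inputs: its own representation at layer k, and an aggregate over its neighbors N(v).

```latex
h_v^{(k+1)} = \mathrm{UPDATE}^{(k)}\left( h_v^{(k)},\; \mathrm{AGGREGATE}^{(k)}\left( \{\, h_u^{(k)} : u \in \mathcal{N}(v) \,\} \right) \right)
```

Each layer reaches one hop further, so after K layers a node's representation depends on its K-hop neighborhood.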
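```python
# Conceptual message passing in plain Python:
# mean aggregation, and an update that averages a node with its neighbors' mean
def message_passing_layer(node_features, edge_index):
    sources, targets = edge_index  # edge s -> t carries a message to t
    new_features = []
    for node in range(len(node_features)):
        # AGGREGATE: collect features from nodes with an edge pointing to `node`
        neighbor_features = [node_features[s] for s, t in zip(sources, targets) if t == node]
        if neighbor_features:
            # mean here; sum and max are common alternatives
            aggregated = [sum(col) / len(neighbor_features) for col in zip(*neighbor_features)]
        else:
            aggregated = [0.0] * len(node_features[node])
        # UPDATE: combine own features with the aggregate (real GNNs use learned weights here)
        new_features.append([(h + m) / 2 for h, m in zip(node_features[node], aggregated)])
    return new_features
```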
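The per-node loop can also be written with matrices, which is closer to how libraries batch the computation. This is a minimal NumPy sketch with three assumptions: sum aggregation, a self-plus-neighbors update, and an identity weight matrix `W` standing in for learned parameters so the output is easy to check. With adjacency matrix `A`, row v of `A @ H` is the sum of node v's neighbors' features.

```python
import numpy as np

# 3-node path graph 0 - 1 - 2, the same graph as the PyTorch Geometric example below
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
H = np.array([[1, 0],
              [0, 1],
              [1, 1]], dtype=float)

def sum_message_passing(A, H, W):
    """One layer: sum neighbor features (A @ H), add self features, project, ReLU."""
    aggregated = A @ H               # AGGREGATE: row v = sum of v's neighbors
    updated = (H + aggregated) @ W   # UPDATE: combine with own features, then project
    return np.maximum(updated, 0.0)  # ReLU nonlinearity

W = np.eye(2)  # identity weights as a stand-in for learned parameters
out = sum_message_passing(A, H, W)
print(out)
```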
GNN with PyTorch Geometric
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data

# Create a simple graph
edge_index = torch.tensor([
    [0, 1, 1, 2],  # Source nodes
    [1, 0, 2, 1]   # Target nodes
], dtype=torch.long)

x = torch.tensor([
    [1, 0],  # Node 0 features
    [0, 1],  # Node 1 features
    [1, 1]   # Node 2 features
], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
```
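`edge_index` uses a coordinate (COO) layout. Row 0 holds source nodes, row 1 holds target nodes, and an undirected edge is stored once in each direction. The NumPy sketch below builds the array above from a plain list of undirected edges. `to_undirected_edge_index` is a hypothetical helper; PyTorch Geometric provides its own utilities for this, which the sketch does not use.

```python
import numpy as np

def to_undirected_edge_index(edges):
    """Hypothetical helper: list of undirected (u, v) pairs -> 2 x 2E COO array."""
    sources, targets = [], []
    for u, v in edges:
        sources += [u, v]  # u -> v and v -> u
        targets += [v, u]
    return np.array([sources, targets], dtype=np.int64)

edge_index = to_undirected_edge_index([(0, 1), (1, 2)])
print(edge_index)
# [[0 1 1 2]
#  [1 0 2 1]]

# Node degree = number of edges that point at each node
degree = np.bincount(edge_index[1], minlength=3)
print(degree)  # [1 2 1]
```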
Graph Convolutional Network (GCN)
```python
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # First message passing layer
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        # Second message passing layer
        x = self.conv2(x, edge_index)
        return x

# For node classification
model = GCN(in_channels=2, hidden_channels=16, out_channels=3)
out = model(data.x, data.edge_index)  # Shape: [num_nodes, 3]
```
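Under the hood, a GCN layer follows Kipf and Welling's propagation rule, which `GCNConv` applies by default. It adds self-loops and normalizes symmetrically by degree: H' = D^(-1/2) (A + I) D^(-1/2) H W, where D is the degree matrix of A + I. Below is a dense NumPy sketch for small graphs. It assumes an identity `W` so the numbers are easy to follow, and omits the bias and the activation.

```python
import numpy as np

def gcn_propagate(A, H, W):
    """One dense GCN layer without bias or activation: D^-1/2 (A + I) D^-1/2 H W."""
    A_hat = A + np.eye(A.shape[0])            # self-loops keep each node's own features
    deg = A_hat.sum(axis=1)                   # degree of each node, counting the self-loop
    D_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
    A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt  # symmetric normalization
    return A_norm @ H @ W

# 3-node path graph 0 - 1 - 2 with the same features as the PyG example above
A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
H = np.array([[1, 0], [0, 1], [1, 1]], dtype=float)
out = gcn_propagate(A, H, np.eye(2))
print(out.round(3))
```

The normalization keeps high-degree nodes from producing much larger sums than low-degree ones.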
GraphSAGE: Scalable GNNs
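The original GCN is trained full-batch on one fixed graph: every layer touches the whole graph at once. GraphSAGE instead learns aggregation functions over a fixed-size sample of each node's neighbors. This keeps the cost per node bounded on large graphs and suits inductive settings, where the model must embed nodes it never saw during training. In PyTorch Geometric, `SAGEConv` is the layer; the sampling itself is usually done by a data loader such as `NeighborLoader`: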
```python
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
```
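GraphSAGE scales because each update reads a fixed-size random subset of a node's neighbors rather than all of them. The update mirrors `SAGEConv`'s default mean aggregator: the node's own transformed features plus the mean of its neighbors' transformed features. Below is a pure-Python sketch of one sampled step. It assumes scalar features and fixed scalar weights `w_self` and `w_neigh` as stand-ins for the learned matrices; all function names are hypothetical.

```python
import random

def sample_neighbors(adjacency, node, fanout, rng):
    """Return at most `fanout` neighbors of `node`, chosen uniformly without replacement."""
    neighbors = adjacency[node]
    if len(neighbors) <= fanout:
        return list(neighbors)
    return rng.sample(neighbors, fanout)

def sage_mean_step(adjacency, features, fanout, w_self=1.0, w_neigh=1.0, seed=0):
    """One GraphSAGE-style update: w_self * h_v + w_neigh * mean(h_u over sampled u)."""
    rng = random.Random(seed)
    updated = {}
    for node, h in features.items():
        sampled = sample_neighbors(adjacency, node, fanout, rng)
        mean_neigh = sum(features[u] for u in sampled) / len(sampled) if sampled else 0.0
        updated[node] = w_self * h + w_neigh * mean_neigh
    return updated

# Hypothetical hub node 0 with 100 neighbors; only `fanout` of them are read per update
adjacency = {0: list(range(1, 101))}
adjacency.update({i: [0] for i in range(1, 101)})
features = {i: float(i) for i in range(101)}
new_features = sage_mean_step(adjacency, features, fanout=5)
print(new_features[0], new_features[1])
```

The hub's update costs 5 feature reads instead of 100, regardless of how many neighbors it has.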
Graph Attention Networks (GAT)
Not all neighbors are equally important. GAT learns an attention weight for each edge, so a node can weight its neighbors' messages differently:
```python
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        # Head outputs are concatenated, giving hidden_channels * heads features
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        # A single head for the output layer
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)

    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
```
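This is how the attention weights arise in a single GAT head. The original GAT formulation scores each edge as e_ij = LeakyReLU(a^T [W h_i || W h_j]), then normalizes the scores with a softmax over each node's neighbors. Below is a NumPy sketch in which random `W` and `a` stand in for learned parameters. Self-loops are added explicitly here, since `GATConv` adds them by default.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_attention(A, H, W, a):
    """Single-head GAT attention weights alpha[i, j] over the neighbors j of node i."""
    Z = H @ W                                    # transformed features W h_i, shape [N, F']
    f_out = Z.shape[1]
    # a^T [z_i || z_j] splits into a_left . z_i + a_right . z_j
    scores = leaky_relu((Z @ a[:f_out])[:, None] + (Z @ a[f_out:])[None, :])
    mask = (A + np.eye(A.shape[0])) > 0          # edges plus self-loops
    scores = np.where(mask, scores, -np.inf)     # non-neighbors get zero weight after softmax
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
H = np.array([[1, 0], [0, 1], [1, 1]], dtype=float)
W = rng.normal(size=(2, 4))  # stand-in for the learned projection
a = rng.normal(size=8)       # stand-in for the learned attention vector (2 * 4 entries)
alpha = gat_attention(A, H, W, a)
H_new = alpha @ (H @ W)      # attention-weighted sum of neighbors, before any nonlinearity
print(alpha.round(3))
```

Each row of `alpha` sums to 1, and entries for non-adjacent pairs are exactly zero.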
Graph-Level Tasks (Graph Classification)
Sometimes you need to classify entire graphs, not nodes:
```python
from torch_geometric.nn import global_mean_pool, global_max_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        # Node embeddings
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))

        # Aggregate to graph-level
        x = global_mean_pool(x, batch)  # Average all nodes per graph

        # Classify
        return self.classifier(x)
```
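`global_mean_pool` relies on the `batch` vector, which assigns each node in a mini-batch to the graph it came from, and averages node embeddings per graph. The same computation in NumPy is shown below for intuition. `mean_pool` is a hypothetical name, and the embeddings are made-up example values.

```python
import numpy as np

def mean_pool(x, batch, num_graphs):
    """Average node embeddings per graph; batch[i] is the graph index of node i."""
    sums = np.zeros((num_graphs, x.shape[1]))
    np.add.at(sums, batch, x)  # scatter-add each node's row into its graph's row
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    return sums / counts

# Two graphs in one mini-batch: nodes 0-1 belong to graph 0, nodes 2-4 to graph 1
x = np.array([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [3.0, 3.0], [6.0, 0.0]])
batch = np.array([0, 0, 1, 1, 1])
pooled = mean_pool(x, batch, num_graphs=2)
print(pooled)
# [[2. 3.]
#  [3. 1.]]
```

Pooling turns a variable number of nodes into one fixed-size vector per graph, which is what the linear classifier needs.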
Link Prediction
Predict whether an edge should exist between two nodes:
```python
@torch.no_grad()
def link_prediction(model, data, node_i, node_j):
    model.eval()  # disable dropout so the score is deterministic

    # Get node embeddings
    embeddings = model(data.x, data.edge_index)

    # Score the potential edge
    emb_i = embeddings[node_i]
    emb_j = embeddings[node_j]

    # Simple: dot product
    score = torch.dot(emb_i, emb_j)
    probability = torch.sigmoid(score)
    return probability
```
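Scoring one pair at a time is fine for inspection. For training, a common setup treats observed edges as positives and randomly sampled node pairs as negatives, then applies binary cross-entropy to the dot-product scores. Below is a NumPy sketch of that loss with made-up embeddings. The helper names are hypothetical; PyTorch Geometric has its own `negative_sampling` utility, which this sketch does not use.

```python
import numpy as np

def edge_scores(emb, pairs):
    """Dot-product score for each (i, j) column of `pairs` (shape [2, num_pairs])."""
    return (emb[pairs[0]] * emb[pairs[1]]).sum(axis=1)

def sample_negative_pairs(num_nodes, pos_pairs, num_samples, rng):
    """Draw random pairs that are neither self-pairs nor positive edges (either direction)."""
    existing = set(zip(*pos_pairs.tolist()))
    negatives = []
    while len(negatives) < num_samples:
        i, j = (int(v) for v in rng.integers(num_nodes, size=2))
        if i != j and (i, j) not in existing and (j, i) not in existing:
            negatives.append((i, j))
    return np.array(negatives).T

def link_bce_loss(emb, pos_pairs, neg_pairs):
    """Binary cross-entropy on sigmoid(dot product): positives -> 1, negatives -> 0."""
    pos = edge_scores(emb, pos_pairs)
    neg = edge_scores(emb, neg_pairs)
    # -log(sigmoid(s)) = log(1 + e^-s) and -log(1 - sigmoid(s)) = log(1 + e^s)
    losses = np.concatenate([np.logaddexp(0.0, -pos), np.logaddexp(0.0, neg)])
    return losses.mean()

rng = np.random.default_rng(0)
emb = rng.normal(size=(5, 8))                       # made-up node embeddings
pos_pairs = np.array([[0, 1, 2, 3], [1, 2, 3, 4]])  # observed edges
neg_pairs = sample_negative_pairs(5, pos_pairs, num_samples=4, rng=rng)
loss = link_bce_loss(emb, pos_pairs, neg_pairs)
print(f"BCE loss: {loss:.4f}")
```

Without negatives, the model could minimize the loss by scoring every pair highly; the sampled non-edges prevent that.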
Training a GNN
```python
from torch_geometric.datasets import Planetoid

# Load Cora dataset (citation network)
dataset = Planetoid(root='data', name='Cora')
data = dataset[0]

model = GCN(
    in_channels=dataset.num_features,
    hidden_channels=16,
    out_channels=dataset.num_classes
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    # Loss only on the labeled training nodes
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    correct = pred[data.test_mask] == data.y[data.test_mask]
    return int(correct.sum()) / int(data.test_mask.sum())

for epoch in range(200):
    loss = train()
    if epoch % 20 == 0:
        acc = test()
        print(f"Epoch {epoch}, Loss: {loss:.4f}, Test Acc: {acc:.4f}")
```
GNN Applications
| Application | Node Task | Edge Task | Graph Task |
|-------------|-----------|-----------|------------|
| Social Networks | User classification | Friend recommendation | Community detection |
| Molecules | Atom property prediction | Bond prediction | Molecular property prediction (e.g., drug discovery) |
| Knowledge Graphs | Entity classification | Link prediction | - |
| Traffic | Road segment | - | Traffic prediction |
Key Takeaway
GNNs extend deep learning to graph-structured data through message passing: nodes aggregate information from their neighbors and learn representations that capture graph structure. Start with GCN for simplicity, use GraphSAGE for large graphs, and try GAT when neighbors should contribute unequally. PyTorch Geometric makes implementation straightforward.