Federated Learning
Privacy-preserving machine learning: train models without sharing raw data.
What is Federated Learning?
Train a model across many devices without ever centralizing the data.
Key idea: Move code to data, not data to code!
Why Federated Learning?
Privacy: Raw data never leaves the device
Security: No central data store to attack
Efficiency: Training uses compute already sitting on edge devices
Regulation: Easier compliance with GDPR, HIPAA, and similar rules
How It Works
- Server sends the current global model to a set of devices
- Each device trains locally on its own data
- Devices send model updates (not data!) back to the server
- Server averages the updates into a new global model
- Repeat until convergence
Simple Example
import copy
import torch
import torch.nn as nn

# Central server: holds the global model and averages client updates
class Server:
    def __init__(self, model):
        self.global_model = model

    def aggregate(self, client_models):
        """Average client model weights (FedAvg with equal weights)."""
        global_dict = self.global_model.state_dict()
        for key in global_dict.keys():
            # Average this parameter across all clients
            global_dict[key] = torch.stack([
                client.state_dict()[key].float()
                for client in client_models
            ]).mean(0)
        self.global_model.load_state_dict(global_dict)
        return self.global_model

# Client device: trains a copy of the model on its own local data
class Client:
    def __init__(self, model, data, device_id):
        self.model = model
        self.data = data
        self.device_id = device_id

    def train(self, epochs=1):
        """Train on local data; the raw data never leaves the client."""
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()
        for epoch in range(epochs):
            for x, y in self.data:
                optimizer.zero_grad()
                output = self.model(x)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()
        return self.model
# Federated training loop (FedAvg)
def federated_learning(server, clients, rounds=10):
    for round_num in range(rounds):
        print(f"Round {round_num + 1}/{rounds}")
        # 1. Server sends the current global model to each client
        client_models = []
        for client in clients:
            # Copy global weights onto the client
            client.model.load_state_dict(server.global_model.state_dict())
            # Client trains locally on its own data
            trained_model = client.train(epochs=5)
            client_models.append(trained_model)
        # 2. Server aggregates the client updates
        server.aggregate(client_models)
        # 3. Evaluate the global model (evaluate and test_data defined elsewhere)
        accuracy = evaluate(server.global_model, test_data)
        print(f"Global accuracy: {accuracy:.2f}%")
# Usage
model = SimpleCNN()
server = Server(model)

# Create clients (e.g., user phones), each with its own local dataset
clients = [
    Client(copy.deepcopy(model), user1_data, "device_1"),
    Client(copy.deepcopy(model), user2_data, "device_2"),
    Client(copy.deepcopy(model), user3_data, "device_3"),
]

federated_learning(server, clients, rounds=50)
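The snippet above leaves SimpleCNN, the userN_data loaders, evaluate, and test_data undefined. The stand-ins below are one hypothetical way to fill those gaps so the example runs end to end; the architecture, shapes, and random data are assumptions, not part of the original example.

from torch.utils.data import DataLoader, TensorDataset

# Hypothetical model: a tiny CNN for 28x28 grayscale images, 10 classes
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
        self.fc = nn.Linear(8 * 14 * 14, 10)

    def forward(self, x):
        return self.fc(self.conv(x).flatten(1))

# Hypothetical per-client data: random tensors standing in for real local datasets
def make_loader(n=64):
    x = torch.randn(n, 1, 28, 28)
    y = torch.randint(0, 10, (n,))
    return DataLoader(TensorDataset(x, y), batch_size=16)

user1_data, user2_data, user3_data = make_loader(), make_loader(), make_loader()
test_data = make_loader(128)

# Hypothetical evaluation helper: accuracy (%) of a model on held-out data
def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            correct += (model(x).argmax(dim=1) == y).sum().item()
            total += y.numel()
    return 100.0 * correct / total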
Using PySyft
import torch
import torch.nn as nn
import torch.nn.functional as F
import syft as sy  # note: TorchHook / VirtualWorker are the PySyft 0.2.x API

# Create virtual devices (simulated remote workers)
hook = sy.TorchHook(torch)
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")

# Send data to the devices (`data` is assumed to be something PySyft can send,
# e.g. a list of (x, y) batches)
alice_data = data[:50].send(alice)
bob_data = data[50:].send(bob)

# Define model
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.LogSoftmax(dim=1),  # nll_loss expects log-probabilities
)

# Federated training, alternating between workers
for epoch in range(10):
    # Train on Alice's device
    model = model.send(alice)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for x, y in alice_data:
        optimizer.zero_grad()
        output = model(x)
        loss = F.nll_loss(output, y)
        loss.backward()
        optimizer.step()
    model = model.get()  # retrieve the updated model from Alice

    # Train on Bob's device
    model = model.send(bob)
    # ... same local training loop, iterating over bob_data ...
    model = model.get()
Using TensorFlow Federated
import tensorflow as tf
import tensorflow_federated as tff

# Define model
def create_keras_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])

# Wrap the Keras model as a TFF model
def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=input_spec,  # element spec of the client datasets
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

# Build the Federated Averaging process
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
)

# Simulate federated training
state = iterative_process.initialize()
for round_num in range(10):
    # One tf.data.Dataset per participating client
    federated_data = [client1_data, client2_data, client3_data]
    # One round of federated training
    state, metrics = iterative_process.next(state, federated_data)
    print(f'Round {round_num}, Metrics: {metrics}')
Differential Privacy
Add noise to protect individual privacy:
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer

# DP-SGD style optimizer: clips per-example gradients and adds Gaussian noise
optimizer = dp_optimizer.DPAdamGaussianOptimizer(
    l2_norm_clip=1.0,        # clip each microbatch gradient to this L2 norm
    noise_multiplier=1.1,    # amount of Gaussian noise added to the clipped gradients
    num_microbatches=1,
    learning_rate=0.001,
)

# Train with differential privacy; the loss is left unreduced so the optimizer
# can clip gradients per microbatch
loss = tf.keras.losses.SparseCategoricalCrossentropy(
    reduction=tf.keras.losses.Reduction.NONE)
model.compile(optimizer=optimizer, loss=loss)
model.fit(X_train, y_train, epochs=10, batch_size=256)

# Privacy budget (epsilon): lower epsilon = more privacy, less accuracy
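To report the epsilon actually spent, TensorFlow Privacy ships a privacy accountant for DP-SGD. A rough sketch follows; the module path and exact signature can differ between tensorflow_privacy versions, and the dataset size here is an assumption.

from tensorflow_privacy.privacy.analysis import compute_dp_sgd_privacy

# Estimate epsilon for the training run above (delta chosen as ~1/dataset_size)
eps, _ = compute_dp_sgd_privacy.compute_dp_sgd_privacy(
    n=60000,               # number of training examples (assumed here)
    batch_size=256,
    noise_multiplier=1.1,
    epochs=10,
    delta=1e-5,
)
print(f"Trained with ({eps:.2f}, 1e-5)-differential privacy")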
Secure Aggregation
Encrypt model updates:
# Conceptual sketch of PySyft-style secure aggregation
# (encrypt, secure_aggregate and decrypt are placeholders for the crypto protocol)
encrypted_updates = []
for client in clients:
    update = client.train()
    encrypted = encrypt(update)
    encrypted_updates.append(encrypted)

# Server aggregates without seeing any individual update
aggregated = secure_aggregate(encrypted_updates)
decrypted = decrypt(aggregated)
# The server only ever sees the average, not individual contributions
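One way to make this concrete without a crypto library is additive masking: each pair of clients agrees on a random mask that one adds and the other subtracts, so the masks cancel in the sum and the server only sees masked vectors. The sketch below is purely illustrative; real protocols derive the masks from pairwise shared keys and handle clients that drop out.

import torch

def mask_updates(updates, seed=0):
    """Add pairwise random masks that cancel when all masked updates are summed."""
    gen = torch.Generator().manual_seed(seed)
    masked = [u.clone() for u in updates]
    n = len(updates)
    for i in range(n):
        for j in range(i + 1, n):
            # Clients i and j agree on a shared random mask (same shape as the update)
            mask = torch.randn(updates[0].shape, generator=gen)
            masked[i] += mask   # client i adds the mask
            masked[j] -= mask   # client j subtracts it
    return masked

# Each "update" stands in for a flattened model delta from one client
updates = [torch.randn(5) for _ in range(3)]
masked = mask_updates(updates)

# The server sums the masked updates; the pairwise masks cancel exactly
server_sum = torch.stack(masked).sum(0)
true_sum = torch.stack(updates).sum(0)
assert torch.allclose(server_sum, true_sum, atol=1e-5)
print("average update:", server_sum / len(updates))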
Challenges
Communication cost: Repeatedly sending full models is expensive
- Solution: Compression and quantization of updates (see the quantization sketch after this list)
Data heterogeneity: Each device has different, non-IID data
- Solution: Personalization layers, FedProx-style training
Device availability: Devices drop offline mid-round
- Solution: Asynchronous updates, per-round client sampling
Convergence: Typically needs more rounds than centralized training
- Solution: More sophisticated aggregation (e.g., weighted averaging, FedProx)
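A simple way to cut communication cost is to quantize each update to 8-bit integers before upload and dequantize on the server. The helper names below are hypothetical; the sketch uses symmetric per-tensor quantization and gives roughly a 4x reduction versus float32.

import torch

def quantize_update(update, num_bits=8):
    """Symmetric per-tensor quantization of a model update to int8."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = max(update.abs().max().item() / qmax, 1e-12)   # avoid division by zero
    q = torch.clamp(torch.round(update / scale), -qmax, qmax).to(torch.int8)
    return q, scale

def dequantize_update(q, scale):
    return q.float() * scale

update = torch.randn(1000)                     # a flattened model delta
q, scale = quantize_update(update)
recovered = dequantize_update(q, scale)

print("bytes before:", update.numel() * 4)     # float32
print("bytes after: ", q.numel())              # int8 (plus one float for the scale)
print("max error:   ", (update - recovered).abs().max().item())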
Real Applications
Google Keyboard: Learns typing patterns without seeing your messages
Apple Siri: Improves voice recognition privately
Healthcare: Train on hospital data without sharing patient records
IoT: Learn from sensors without centralizing all data
Advanced Aggregation
import torch
import torch.nn.functional as F

# Weighted averaging: weight each client by its local dataset size (FedAvg)
def weighted_aggregate(client_models, client_sizes):
    total_size = sum(client_sizes)
    weights = [size / total_size for size in client_sizes]
    global_dict = {}
    for key in client_models[0].state_dict().keys():
        global_dict[key] = sum(
            weight * model.state_dict()[key]
            for weight, model in zip(weights, client_models)
        )
    return global_dict

# FedProx: adds a proximal term that keeps local models close to the global model,
# which helps when client data is heterogeneous (non-IID)
def fedprox_train(model, data, global_model, mu=0.01):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for x, y in data:
        optimizer.zero_grad()
        output = model(x)
        loss = F.cross_entropy(output, y)
        # Proximal term: squared distance between local and global parameters
        proximal_term = 0.0
        for param, global_param in zip(model.parameters(), global_model.parameters()):
            proximal_term += ((param - global_param) ** 2).sum()
        total_loss = loss + (mu / 2) * proximal_term
        total_loss.backward()
        optimizer.step()
    return model
Best Practices
- Use secure aggregation so the server never sees individual updates
- Add differential privacy when the threat model requires it
- Handle stragglers (slow or offline devices), e.g. by sampling clients each round (see the sketch below)
- Monitor convergence carefully; non-IID data makes it noisy
- Test communication efficiency before deploying to real devices
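One common way to tolerate stragglers is to sample only a fraction of clients per round and aggregate whichever of them report back. This is a minimal sketch built on the Server and Client classes from the simple example above; the sampling fraction and the "report back" notion are assumptions about your deployment.

import random

def federated_round(server, clients, fraction=0.1, local_epochs=1):
    """Run one round with a random subset of clients, skipping non-reporters."""
    k = max(1, int(len(clients) * fraction))
    selected = random.sample(clients, k)

    reported = []
    for client in selected:
        client.model.load_state_dict(server.global_model.state_dict())
        trained = client.train(epochs=local_epochs)
        # In a real deployment a client may time out or drop offline here;
        # only models that actually report back get aggregated.
        reported.append(trained)

    if reported:
        server.aggregate(reported)
    return server.global_model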
Remember
- Federated learning enables privacy-preserving ML
- Data never leaves devices
- More complex than centralized training
- Growing in importance with privacy regulations