# Federated Learning

Train models without sharing raw data: privacy-preserving machine learning.
## What is Federated Learning?

Federated learning trains a model across many devices without ever centralizing the data.
**Key idea**: Move code to data, not data to code!
## Why Federated Learning?

- **Privacy**: Data stays on the device
- **Security**: No central data store to attack
- **Efficiency**: Uses edge computing
- **Regulation**: Helps with GDPR and HIPAA compliance
## How It Works

1. Server sends the current model to devices
2. Devices train locally on their own data
3. Devices send updates (not data!) back to the server
4. Server averages the updates
5. Repeat
## Simple Example
```python
import copy

import torch
import torch.nn as nn

# SimpleCNN, evaluate, test_data, and userN_data are assumed to be defined elsewhere

# Central server
class Server:
    def __init__(self, model):
        self.global_model = model

    def aggregate(self, client_models):
        """Average client model weights (FedAvg)."""
        global_dict = self.global_model.state_dict()
        for key in global_dict.keys():
            # Average weights from all clients
            global_dict[key] = torch.stack([
                client.state_dict()[key].float()
                for client in client_models
            ]).mean(0)
        self.global_model.load_state_dict(global_dict)
        return self.global_model

# Client device
class Client:
    def __init__(self, model, data, device_id):
        self.model = model
        self.data = data
        self.device_id = device_id

    def train(self, epochs=1):
        """Train on local data."""
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()
        for epoch in range(epochs):
            for x, y in self.data:
                optimizer.zero_grad()
                output = self.model(x)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()
        return self.model

# Federated training loop
def federated_learning(server, clients, rounds=10):
    for round_num in range(rounds):
        print(f"Round {round_num + 1}/{rounds}")

        # 1. Server sends model to clients
        client_models = []
        for client in clients:
            # Copy global model to client
            client.model.load_state_dict(server.global_model.state_dict())
            # Client trains locally
            trained_model = client.train(epochs=5)
            client_models.append(trained_model)

        # 2. Server aggregates client updates
        server.aggregate(client_models)

        # 3. Evaluate global model
        accuracy = evaluate(server.global_model, test_data)
        print(f"Global accuracy: {accuracy:.2f}%")

# Usage
model = SimpleCNN()
server = Server(model)

# Create clients (e.g., user phones)
clients = [
    Client(copy.deepcopy(model), user1_data, "device_1"),
    Client(copy.deepcopy(model), user2_data, "device_2"),
    Client(copy.deepcopy(model), user3_data, "device_3"),
]

federated_learning(server, clients, rounds=50)
```
## Using PySyft

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import syft as sy

# Simplified sketch using the old PySyft 0.2-style hooked API

# Create virtual devices
hook = sy.TorchHook(torch)
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")

# Send data to devices (`data` is assumed to exist; in practice PySyft
# provides federated data loaders for batching)
alice_data = data[:50].send(alice)
bob_data = data[50:].send(bob)

# Define model
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# Federated training loop
for epoch in range(10):
    # Train on Alice's device
    model = model.send(alice)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for x, y in alice_data:
        optimizer.zero_grad()
        output = model(x)
        loss = F.nll_loss(output, y)
        loss.backward()
        optimizer.step()
    model = model.get()  # Retrieve from Alice

    # Train on Bob's device
    model = model.send(bob)
    # ... same training ...
    model = model.get()
```
## Using TensorFlow Federated

```python
import tensorflow as tf
import tensorflow_federated as tff

# Define model
def create_keras_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

# Convert to a TFF model
def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=input_spec,  # element_spec of the client datasets
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

# Build the federated averaging process
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02)
)

# Simulate federated training
state = iterative_process.initialize()

for round_num in range(10):
    # Each client's dataset
    federated_data = [client1_data, client2_data, client3_data]

    # One round of federated training
    state, metrics = iterative_process.next(state, federated_data)
    print(f'Round {round_num}, Metrics: {metrics}')
```
## Differential Privacy

Add noise during training to protect individual privacy:
```python
from tensorflow_privacy.privacy.optimizers import dp_optimizer

# DP-SGD-style optimizer (Adam variant with Gaussian noise)
optimizer = dp_optimizer.DPAdamGaussianOptimizer(
    l2_norm_clip=1.0,       # Clip per-example gradients
    noise_multiplier=1.1,   # Amount of noise added to the clipped gradients
    num_microbatches=1,
    learning_rate=0.001
)

# Train with differential privacy
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')
model.fit(X_train, y_train, epochs=10, batch_size=256)

# Privacy budget (epsilon): lower epsilon = more privacy, less accuracy
```
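To relate these settings to an actual privacy budget, tensorflow_privacy includes an accounting helper that estimates epsilon from the training parameters. A minimal sketch, assuming the `compute_dp_sgd_privacy` helper (its import path varies across library versions) and illustrative values:

```python
# Import path may differ depending on your tensorflow_privacy version
from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy import (
    compute_dp_sgd_privacy,
)

# Estimate epsilon for the DP training run above (values are illustrative)
eps, opt_order = compute_dp_sgd_privacy(
    n=60000,               # number of training examples
    batch_size=256,
    noise_multiplier=1.1,  # must match the optimizer setting
    epochs=10,
    delta=1e-5,            # target delta, typically <= 1/n
)
print(f"epsilon = {eps:.2f} at delta = 1e-5")
```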
## Secure Aggregation

Encrypt model updates so the server never sees an individual contribution:
```python
# Pseudocode; frameworks such as PySyft provide secure aggregation primitives

# Each client encrypts its model update
encrypted_updates = []
for client in clients:
    update = client.train()
    encrypted = encrypt(update)
    encrypted_updates.append(encrypted)

# Server aggregates without seeing individual updates
aggregated = secure_aggregate(encrypted_updates)
decrypted = decrypt(aggregated)

# The server only sees the average, not individual contributions
```
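A common construction behind secure aggregation is pairwise additive masking: every pair of clients shares a random mask that one adds and the other subtracts, so individual uploads look random but the masks cancel when the server sums them. A toy PyTorch sketch, ignoring key agreement and client dropout (function names are illustrative, not a library API):

```python
import torch

def pairwise_masks(num_clients, shape, seed=0):
    """Generate masks with masks[i][j] = -masks[j][i], so they cancel in the sum."""
    gen = torch.Generator().manual_seed(seed)
    masks = [[torch.zeros(shape) for _ in range(num_clients)] for _ in range(num_clients)]
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            m = torch.randn(shape, generator=gen)
            masks[i][j] = m
            masks[j][i] = -m
    return masks

# Toy updates from 3 clients (in practice: flattened model deltas)
updates = [torch.randn(4) for _ in range(3)]
masks = pairwise_masks(num_clients=3, shape=(4,))

# Each client uploads its update plus the sum of its masks; individually these look random
masked = [u + sum(masks[i]) for i, u in enumerate(updates)]

# The masks cancel in the sum, so the server recovers only the total
server_sum = sum(masked)
assert torch.allclose(server_sum, sum(updates), atol=1e-5)
print(server_sum / len(updates))  # average update
```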
## Challenges

- **Communication cost**: Sending models each round is expensive
  - Solution: compression and quantization (see the sketch after this list)
- **Data heterogeneity**: Each device has different data
  - Solution: personalization layers
- **Device availability**: Devices go offline
  - Solution: asynchronous updates
- **Convergence**: Takes longer than centralized training
  - Solution: more sophisticated aggregation
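One concrete form of the compression idea is top-k sparsification: each client sends only the largest-magnitude entries of its update. A minimal PyTorch sketch (the function names and the 1% threshold are illustrative, not part of any library):

```python
import torch

def sparsify_topk(update: torch.Tensor, k_fraction: float = 0.01):
    """Keep only the largest-magnitude k% of an update; send (indices, values, shape)."""
    flat = update.flatten()
    k = max(1, int(k_fraction * flat.numel()))
    _, indices = torch.topk(flat.abs(), k)
    return indices, flat[indices], update.shape

def densify(indices, values, shape):
    """Rebuild a dense update on the server; dropped entries are treated as zero."""
    dense = torch.zeros(shape)
    dense.view(-1)[indices] = values
    return dense

# Client side: send ~1% of the update instead of the full tensor
update = torch.randn(128, 64)
indices, values, shape = sparsify_topk(update, k_fraction=0.01)

# Server side: reconstruct and aggregate as usual
recovered = densify(indices, values, shape)
```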
## Real Applications

- **Google Keyboard**: Learns typing patterns without seeing your messages
- **Apple Siri**: Improves voice recognition privately
- **Healthcare**: Train on hospital data without sharing patient records
- **IoT**: Learn from sensors without centralizing all the data
## Advanced Aggregation

```python
import torch
import torch.nn.functional as F

# Weighted averaging (by dataset size)
def weighted_aggregate(client_models, client_sizes):
    total_size = sum(client_sizes)
    weights = [size / total_size for size in client_sizes]

    global_dict = {}
    for key in client_models[0].state_dict().keys():
        global_dict[key] = sum(
            weight * model.state_dict()[key]
            for weight, model in zip(weights, client_models)
        )
    return global_dict

# FedProx (handles heterogeneity better)
def fedprox_train(model, data, global_model, mu=0.01):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for x, y in data:
        optimizer.zero_grad()
        output = model(x)
        loss = F.cross_entropy(output, y)

        # Proximal term keeps the local model close to the global model
        proximal_term = 0
        for param, global_param in zip(model.parameters(), global_model.parameters()):
            proximal_term += ((param - global_param) ** 2).sum()

        total_loss = loss + (mu / 2) * proximal_term
        total_loss.backward()
        optimizer.step()
```
## Best Practices

1. Use secure aggregation
2. Add differential privacy when needed
3. Handle stragglers (slow devices); see the sketch after this list
4. Monitor convergence carefully
5. Test communication efficiency
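For straggler handling (practice 3), one simple approach is to sample a subset of clients each round and aggregate only the updates that arrive before a deadline. A minimal sketch building on the `Server` class from the simple example above; `collect_update` and the timeout value are hypothetical placeholders:

```python
import random

def run_round(server, clients, sample_fraction=0.3, timeout_s=60):
    """Sample a fraction of clients and aggregate only the updates that arrive in time."""
    sampled = random.sample(clients, max(1, int(sample_fraction * len(clients))))

    updates = []
    for client in sampled:
        # collect_update is a hypothetical helper: it trains on the client and returns
        # its model, or None if the device went offline or exceeded the deadline
        update = collect_update(client, server.global_model, timeout_s=timeout_s)
        if update is not None:
            updates.append(update)

    # Aggregate whatever arrived; skip the round if too few clients responded
    if len(updates) >= 2:
        server.aggregate(updates)
    return len(updates)
```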
## Remember

- Federated learning enables privacy-preserving ML
- Data never leaves the devices
- More complex than centralized training
- Growing in importance with privacy regulations