Intermediate
How Federated Learning Works
Deep dive into the FedAvg algorithm, aggregation strategies, communication efficiency, and the challenges of non-IID data distributions.
Federated Averaging (FedAvg)
FedAvg is the foundational federated learning algorithm proposed by McMahan et al. (2017). It combines local SGD with periodic model averaging:
Python - FedAvg Pseudocode
```python
import copy
import random

def federated_averaging(global_model, clients_data, rounds, local_epochs, lr):
    for round_num in range(rounds):
        # 1. Select a random subset (~10%) of clients
        selected = random.sample(clients_data, k=max(1, int(0.1 * len(clients_data))))

        client_weights = []
        client_sizes = []
        for client in selected:
            # 2. Send the global model to the client
            local_model = copy.deepcopy(global_model)

            # 3. Train locally for a few epochs
            for epoch in range(local_epochs):
                for batch in client.dataloader:
                    loss = train_step(local_model, batch, lr)

            client_weights.append(local_model.get_weights())
            client_sizes.append(len(client.data))

        # 4. Aggregate: weighted average of parameters by dataset size
        total = sum(client_sizes)
        new_weights = [
            sum(w[i] * (n / total) for w, n in zip(client_weights, client_sizes))
            for i in range(len(client_weights[0]))
        ]
        global_model.set_weights(new_weights)

    return global_model
```
Aggregation Strategies
| Strategy | Description | When to Use |
|---|---|---|
| FedAvg | Weighted average of model parameters by dataset size | Default choice; works well with IID data |
| FedProx | Adds proximal term to keep local models close to global | Non-IID data, heterogeneous systems |
| FedMA | Matches neurons by functionality before averaging, rather than by position | When naive coordinate-wise averaging fails (e.g. permuted neurons, differing layer widths) |
| FedBN | Keeps batch normalization layers local | Feature distribution shift across clients |
| Scaffold | Uses control variates to correct client drift | Strongly non-IID data |
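To make FedProx's proximal term concrete, here is a minimal sketch of a single local SGD step. The function name, the gradient argument, and the toy values are illustrative, not from any particular library; the core idea is that the penalty (mu/2)·||w − w_global||² adds a gradient term pulling the local model back toward the global weights.

```python
import numpy as np

def fedprox_local_update(local_w, global_w, grad, lr=0.01, mu=0.1):
    """One FedProx-style local SGD step (sketch).

    The proximal gradient mu * (w - w_global) discourages the local
    model from drifting far from the global model during local training.
    """
    prox_grad = mu * (local_w - global_w)
    return local_w - lr * (grad + prox_grad)

# Toy example: with a zero task gradient, the update only shrinks
# the gap between local and global weights (from 1.0 toward 0.0).
global_w = np.zeros(3)
local_w = np.ones(3)
updated = fedprox_local_update(local_w, global_w, grad=np.zeros(3), lr=0.1, mu=1.0)
# each coordinate moves from 1.0 to 0.9
```

Setting mu = 0 recovers plain FedAvg local training; larger mu trades local fit for stability under heterogeneous data.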
The Non-IID Challenge
In real-world FL, data across clients is non-IID (not independently and identically distributed). A user who only types in French has very different keyboard data than one who types in English. This causes:
- Client drift: Local models diverge from each other, making averaging less effective.
- Slow convergence: More communication rounds are needed to reach good performance.
- Biased models: The global model may perform poorly on minority data distributions.
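A common way to simulate this kind of label skew in experiments is a Dirichlet partition: each class is split across clients with proportions drawn from a Dirichlet(alpha) distribution, where a small alpha produces strongly non-IID shards. The sketch below (function name and toy labels are illustrative) shows one such split:

```python
import numpy as np

def dirichlet_partition(labels, n_clients, alpha, seed=0):
    """Split sample indices across clients with per-class proportions
    drawn from Dirichlet(alpha). Small alpha => each client is
    dominated by a few classes (strongly non-IID)."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(n_clients)]
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]
        rng.shuffle(idx)
        # Fraction of class c assigned to each client
        props = rng.dirichlet(alpha * np.ones(n_clients))
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for client, shard in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(shard.tolist())
    return client_indices

labels = np.repeat(np.arange(10), 100)   # 10 classes, 100 samples each
parts = dirichlet_partition(labels, n_clients=5, alpha=0.1)
# With alpha=0.1, most clients hold only a handful of classes
```

Running FedAvg on partitions like this (versus alpha → infinity, which approaches IID) is a standard way to measure how much client drift slows convergence.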
Communication Efficiency
Communication is the bottleneck in FL. Techniques to reduce it:
- Gradient compression: Quantize or sparsify gradients before transmission (e.g., send only top-k% of gradient values).
- Fewer rounds: Increase local epochs (more local training per round) to reduce the total number of communication rounds.
- Federated distillation: Share model predictions instead of weights — much smaller payloads.
- Partial model updates: Only communicate layers that changed significantly.
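The top-k sparsification idea above can be sketched in a few lines: each client transmits only the indices and values of its largest-magnitude gradient entries, and the server rebuilds a sparse gradient. Function names here are illustrative:

```python
import numpy as np

def topk_sparsify(grad, k_frac=0.01):
    """Keep only the top k-fraction of gradient entries by magnitude;
    transmit (indices, values) instead of the dense vector."""
    k = max(1, int(k_frac * grad.size))
    flat = grad.ravel()
    idx = np.argpartition(np.abs(flat), -k)[-k:]
    return idx, flat[idx]

def densify(idx, vals, shape):
    """Server-side reconstruction of the sparse gradient."""
    out = np.zeros(int(np.prod(shape)))
    out[idx] = vals
    return out.reshape(shape)

rng = np.random.default_rng(0)
grad = rng.standard_normal(1000)
idx, vals = topk_sparsify(grad, k_frac=0.01)  # send 10 values, not 1000
recovered = densify(idx, vals, grad.shape)
```

In practice, systems that use this also accumulate the dropped (residual) gradient locally and add it back in later rounds, so small but persistent signals are not lost forever.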
Client Selection
Not all clients participate in every round. Selection strategies include:
- Random sampling: Randomly select a fraction (e.g., 10%) of clients per round. Simple and unbiased.
- Resource-aware: Only select clients that have sufficient battery, are on Wi-Fi, and are idle.
- Contribution-based: Prioritize clients whose data distributions complement the current model's weaknesses.
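Resource-aware selection composed with random sampling can be sketched as a filter-then-sample step. The client fields (`battery`, `on_wifi`, `idle`) and thresholds below are made-up illustrations of the kind of eligibility checks a production system might run:

```python
import random

def select_clients(clients, fraction=0.1, seed=None):
    """Sketch: filter to eligible devices, then sample a fraction at random.

    Client dicts and the 30%-battery threshold are hypothetical; real
    systems check device-reported state before inviting a client."""
    eligible = [c for c in clients
                if c["battery"] > 0.3 and c["on_wifi"] and c["idle"]]
    k = max(1, int(fraction * len(eligible)))
    rng = random.Random(seed)
    return rng.sample(eligible, k=min(k, len(eligible)))

clients = [{"id": i, "battery": i / 10, "on_wifi": i % 2 == 0, "idle": True}
           for i in range(10)]
chosen = select_clients(clients, fraction=0.5, seed=0)
```

Note that non-random filters like this can bias the global model toward data from well-resourced devices, which is one motivation for the contribution-based strategies above.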
Key takeaway: FedAvg is the foundation of federated learning. The main challenges are non-IID data (causing client drift) and communication efficiency. Advanced algorithms like FedProx and Scaffold address non-IID issues, while compression and distillation reduce communication costs.
Lilly Tech Systems