Intermediate

How Federated Learning Works

Deep dive into the FedAvg algorithm, aggregation strategies, communication efficiency, and the challenges of non-IID data distributions.

Federated Averaging (FedAvg)

FedAvg is the foundational federated learning algorithm proposed by McMahan et al. (2017). It combines local SGD with periodic model averaging:

Python - FedAvg Pseudocode
import copy
import random

def federated_averaging(global_model, clients_data, rounds, local_epochs, lr):
    for round_num in range(rounds):
        # 1. Select subset of clients
        selected = random.sample(clients_data, k=max(1, int(0.1 * len(clients_data))))

        client_weights = []
        client_sizes = []

        for client in selected:
            # 2. Send global model to client
            local_model = copy.deepcopy(global_model)

            # 3. Train locally (train_step is a placeholder for one
            #    SGD update on a single batch)
            for epoch in range(local_epochs):
                for batch in client.dataloader:
                    loss = train_step(local_model, batch, lr)

            client_weights.append(local_model.get_weights())
            client_sizes.append(len(client.data))

        # 4. Aggregate: weighted average by data size
        total = sum(client_sizes)
        new_weights = [
            sum(w[i] * (n / total) for w, n in zip(client_weights, client_sizes))
            for i in range(len(client_weights[0]))
        ]
        global_model.set_weights(new_weights)

    return global_model
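The aggregation step (4) can be exercised in isolation. Below is a minimal, runnable NumPy sketch of the size-weighted average with two toy one-layer clients; the `weighted_average` helper name is illustrative, not from the original paper:

```python
import numpy as np

def weighted_average(client_weights, client_sizes):
    """Aggregate per-layer weight arrays, weighting each client by its dataset size."""
    total = sum(client_sizes)
    return [
        sum(w[i] * (n / total) for w, n in zip(client_weights, client_sizes))
        for i in range(len(client_weights[0]))
    ]

# Two clients, one layer each; client B holds 3x as much data as client A.
a = [np.array([0.0, 0.0])]
b = [np.array([4.0, 8.0])]
avg = weighted_average([a, b], [100, 300])
# avg[0] == [3., 6.] — B's parameters get 3x the influence of A's
```

This is why FedAvg is "weighted": a client with more data pulls the global model further toward its local solution.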

Aggregation Strategies

| Strategy | Description | When to Use |
|----------|-------------|-------------|
| FedAvg | Weighted average of model parameters by dataset size | Default choice; works well with IID data |
| FedProx | Adds proximal term to keep local models close to global | Non-IID data, heterogeneous systems |
| FedMA | Matches and averages neurons by functionality | When model architectures differ |
| FedBN | Keeps batch normalization layers local | Feature distribution shift across clients |
| SCAFFOLD | Uses control variates to correct client drift | Strongly non-IID data |
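To make the FedProx row concrete: the proximal term adds (mu/2)·||w − w_global||² to each client's local objective, so every local gradient step gains an extra pull back toward the global model. A minimal sketch of one such step, assuming plain NumPy arrays for weights (the function name and `mu` default are illustrative):

```python
import numpy as np

def fedprox_grad_step(w_local, w_global, grad, lr=0.1, mu=0.01):
    """One local SGD step under FedProx.

    Local objective: f(w) + (mu/2) * ||w - w_global||^2, whose gradient
    adds mu * (w_local - w_global) — a pull toward the global model.
    mu = 0 recovers plain FedAvg local training.
    """
    return w_local - lr * (grad + mu * (w_local - w_global))

w_global = np.zeros(3)
w_local = np.array([1.0, 1.0, 1.0])
grad = np.array([0.5, -0.5, 0.0])

w_next = fedprox_grad_step(w_local, w_global, grad, lr=0.1, mu=0.1)
# Even where grad is zero (third coordinate), w_local is nudged toward w_global.
```

Larger `mu` limits how far a client can drift in one round, at the cost of slower local progress.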

The Non-IID Challenge

In real-world FL, data across clients is non-IID (not independently and identically distributed). A user who only types in French has very different keyboard data than one who types in English. This causes:

  • Client drift: Local models diverge from each other, making averaging less effective.
  • Slow convergence: More communication rounds are needed to reach good performance.
  • Biased models: The global model may perform poorly on minority data distributions.
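A common way to reproduce this skew in experiments is Dirichlet label partitioning: each class is split across clients with Dirichlet-distributed proportions, where a smaller concentration parameter alpha yields more extreme non-IID splits. A self-contained sketch (the helper name and alpha value are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

def dirichlet_partition(labels, n_clients, alpha):
    """Split sample indices across clients with Dirichlet-skewed label mixes.

    Smaller alpha -> each client sees only a few classes (strongly non-IID);
    large alpha -> near-uniform class mixes (approximately IID).
    """
    clients = [[] for _ in range(n_clients)]
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        rng.shuffle(idx)
        # Fraction of class c assigned to each client
        props = rng.dirichlet(alpha * np.ones(n_clients))
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for client, chunk in zip(clients, np.split(idx, cuts)):
            client.extend(chunk.tolist())
    return clients

labels = np.repeat([0, 1, 2], 100)                    # 300 samples, 3 classes
clients = dirichlet_partition(labels, n_clients=5, alpha=0.1)
```

With alpha = 0.1, most clients end up dominated by one or two classes, which is exactly the regime where plain FedAvg suffers client drift.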

Communication Efficiency

Communication is the bottleneck in FL. Techniques to reduce it:

  • Gradient compression: Quantize or sparsify gradients before transmission (e.g., send only top-k% of gradient values).
  • Fewer rounds: Increase local epochs (more local training per round) to reduce the total number of communication rounds.
  • Federated distillation: Share model predictions instead of weights — much smaller payloads.
  • Partial model updates: Only communicate layers that changed significantly.
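The first bullet, top-k gradient sparsification, can be sketched in a few lines of NumPy. Only the surviving indices and values are transmitted; the helper name and example values below are illustrative:

```python
import numpy as np

def topk_sparsify(grad, k_frac=0.01):
    """Keep the largest-magnitude k fraction of gradient entries; zero the rest.

    Returns (indices, values) — the compressed payload actually sent —
    plus the dense reconstruction for comparison.
    """
    flat = grad.ravel()
    k = max(1, int(k_frac * flat.size))
    idx = np.argpartition(np.abs(flat), -k)[-k:]   # top-k by magnitude
    sparse = np.zeros_like(flat)
    sparse[idx] = flat[idx]
    return idx, flat[idx], sparse.reshape(grad.shape)

grad = np.array([[0.01, -2.0], [0.5, 0.003]])
idx, vals, dense = topk_sparsify(grad, k_frac=0.5)  # keep top 50% = 2 entries
# dense retains only -2.0 and 0.5; the two small entries are zeroed out
```

In practice the dropped residual is usually accumulated locally and added back into the next round's gradient, so small but persistent signals are not lost forever.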

Client Selection

Not all clients participate in every round. Selection strategies include:

  • Random sampling: Randomly select a fraction (e.g., 10%) of clients per round. Simple and unbiased.
  • Resource-aware: Only select clients that have sufficient battery, are on Wi-Fi, and are idle.
  • Contribution-based: Prioritize clients whose data distributions complement the current model's weaknesses.
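The first two strategies compose naturally: filter for eligible devices, then sample randomly among them. A minimal sketch, assuming a simple client record with battery, connectivity, and idle flags (the `Client` dataclass and thresholds are illustrative, not a real FL framework API):

```python
import random
from dataclasses import dataclass

@dataclass
class Client:
    cid: int
    battery: float   # state of charge, 0.0 to 1.0
    on_wifi: bool
    idle: bool

def select_clients(clients, fraction=0.1, min_battery=0.8):
    """Resource-aware selection: keep eligible devices, then sample a fraction."""
    eligible = [c for c in clients
                if c.battery >= min_battery and c.on_wifi and c.idle]
    k = max(1, int(fraction * len(clients)))
    return random.sample(eligible, k=min(k, len(eligible)))

pool = [Client(i, random.random(), True, True) for i in range(100)]
chosen = select_clients(pool, fraction=0.1)
# Every chosen client is charged, on Wi-Fi, and idle; at most 10 are picked.
```

Note that resource-aware filtering introduces sampling bias (e.g. toward newer phones that charge overnight), which is one motivation for the contribution-based strategies above.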

Key takeaway: FedAvg is the foundation of federated learning. The main challenges are non-IID data (causing client drift) and communication efficiency. Advanced algorithms like FedProx and SCAFFOLD address non-IID issues, while compression and distillation reduce communication costs.