Intermediate
Federated Learning Frameworks
Hands-on with PySyft, TensorFlow Federated, Flower, and NVIDIA FLARE — the leading frameworks for building federated learning systems.
Framework Comparison
| Framework | ML Backend | Best For | Production Ready |
|---|---|---|---|
| Flower | Any (PyTorch, TF, etc.) | Research and production, any ML framework | Yes |
| TFF | TensorFlow | Google-scale simulation and research | Research-focused |
| PySyft | PyTorch | Privacy-preserving ML with advanced crypto | Growing |
| NVIDIA FLARE | Any | Enterprise FL, healthcare, cross-silo | Yes |
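Under the hood, all four frameworks implement some variant of federated averaging (FedAvg) on the server: each client's parameters are weighted by its number of training examples. A minimal, framework-independent sketch of that aggregation step (the parameter lists and client sizes here are purely illustrative):

```python
def fed_avg(client_params, client_sizes):
    """Weighted average of per-client parameter lists (FedAvg aggregation)."""
    total = sum(client_sizes)
    num_params = len(client_params[0])
    return [
        sum(params[i] * size for params, size in zip(client_params, client_sizes)) / total
        for i in range(num_params)
    ]

# Two clients, one scalar "parameter" each; client 1 holds 3x the data,
# so its value dominates the average: (1.0 * 30 + 5.0 * 10) / 40 = 2.0
global_params = fed_avg([[1.0], [5.0]], [30, 10])
print(global_params)  # [2.0]
```

Every framework below wraps this same idea in its own client/server machinery.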
Flower (flwr)
Flower is the most widely adopted FL framework, largely because it is framework-agnostic: it works with PyTorch, TensorFlow, JAX, or any other ML library:
Python - Flower FL with PyTorch
```python
import flwr as fl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Define a simple model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Copy global parameters into a local model
def set_parameters(model, parameters):
    state_dict = {
        k: torch.tensor(v)
        for k, v in zip(model.state_dict().keys(), parameters)
    }
    model.load_state_dict(state_dict)

# Define Flower client (train/test are your standard PyTorch train/eval loops)
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, model, trainloader, testloader):
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader

    def get_parameters(self, config):
        return [val.numpy() for val in self.model.state_dict().values()]

    def fit(self, parameters, config):
        # Update model with global parameters, then train locally
        set_parameters(self.model, parameters)
        train(self.model, self.trainloader, epochs=1)
        return self.get_parameters(config), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        set_parameters(self.model, parameters)
        loss, accuracy = test(self.model, self.testloader)
        return loss, len(self.testloader.dataset), {"accuracy": accuracy}

# Start Flower client (trainloader/testloader are your local DataLoaders)
fl.client.start_numpy_client(
    server_address="localhost:8080",
    client=FlowerClient(Net(), trainloader, testloader),
)
```
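The client above needs a Flower server to connect to. A companion server launch might look like the following sketch (the address, round count, and strategy parameters are illustrative; Flower's `FedAvg` is the default strategy):

```python
import flwr as fl

# FedAvg strategy: wait until at least 2 clients are connected each round
strategy = fl.server.strategy.FedAvg(
    min_fit_clients=2,
    min_available_clients=2,
)

# Blocks and coordinates 3 federated rounds with connecting clients
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)
```

Run the server first, then start one client process per participant pointing at its address.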
TensorFlow Federated (TFF)
TFF provides two layers: a high-level Federated Learning API and a low-level Federated Core API for custom federated algorithms:
Python - TensorFlow Federated
```python
import tensorflow_federated as tff
import tensorflow as tf

# Load federated dataset (EMNIST, partitioned by writer)
train_data, test_data = tff.simulation.datasets.emnist.load_data()

# Preprocess each client's dataset into (features, label) batches
def preprocess(dataset):
    def batch_format(element):
        return (tf.reshape(element["pixels"], [-1, 28, 28, 1]), element["label"])
    return dataset.batch(32).map(batch_format)

# Pick a few simulated clients for training
train_datasets = [
    preprocess(train_data.create_tf_dataset_for_client(cid))
    for cid in train_data.client_ids[:10]
]

# Define model function
def create_keras_model():
    return tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])

# Wrap for TFF
def model_fn():
    return tff.learning.models.from_keras_model(
        create_keras_model(),
        input_spec=train_datasets[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

# Build federated averaging process
trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
)

# Run federated training
state = trainer.initialize()
for round_num in range(10):
    result = trainer.next(state, train_datasets)
    state = result.state
    print(f"Round {round_num}: {result.metrics}")
```
PySyft
PySyft focuses on privacy-preserving ML, combining federated learning with advanced cryptographic techniques like secure multi-party computation and homomorphic encryption.
- Remote execution: Send computation to data, not data to computation.
- Privacy budget: Track and enforce differential privacy budgets across queries.
- Data governance: Data owners control who can access their data and for what purpose.
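The privacy-budget idea can be illustrated with a simple epsilon accountant. This is a conceptual sketch under basic (additive) composition, not PySyft's actual API; the class and method names are hypothetical:

```python
class PrivacyBudget:
    """Toy differential-privacy accountant using basic additive composition."""

    def __init__(self, total_epsilon):
        self.total_epsilon = total_epsilon
        self.spent = 0.0

    def charge(self, epsilon):
        # Refuse the query if it would exceed the data owner's budget
        if self.spent + epsilon > self.total_epsilon:
            raise PermissionError("privacy budget exhausted")
        self.spent += epsilon
        return self.total_epsilon - self.spent  # remaining budget

budget = PrivacyBudget(total_epsilon=1.0)
print(budget.charge(0.5))   # 0.5 remaining
print(budget.charge(0.25))  # 0.25 remaining; another 0.5 query would be refused
```

Real accountants use tighter composition theorems (e.g. Rényi DP), but the enforcement pattern is the same: every query spends budget, and exhausted budgets block further access.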
NVIDIA FLARE
NVIDIA Federated Learning Application Runtime Environment (FLARE) is designed for enterprise cross-silo FL, particularly in healthcare:
- Privacy: Built-in differential privacy and homomorphic encryption.
- Reliability: Handles client disconnections, checkpoint recovery, and fault tolerance.
- Healthcare focus: Pre-built workflows for medical imaging and electronic health records.
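The fault-tolerance behavior can be pictured as an aggregation round that proceeds only when enough clients respond, skipping dropouts. This is a conceptual sketch, not FLARE's actual API; the function and parameter names are illustrative:

```python
def aggregate_round(client_results, min_clients=2):
    """Average updates from responding clients; abort if too few survived."""
    live = [r for r in client_results if r is not None]  # None = disconnected
    if len(live) < min_clients:
        raise RuntimeError(f"only {len(live)} clients responded; round aborted")
    num_params = len(live[0])
    return [sum(r[i] for r in live) / len(live) for i in range(num_params)]

# Three clients; one dropped mid-round (None) -- the round still completes
print(aggregate_round([[1.0, 2.0], None, [3.0, 4.0]]))  # [2.0, 3.0]
```

A production runtime like FLARE layers checkpointing on top of this, so an aborted round can resume from the last aggregated state rather than restarting training.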
Key takeaway: Flower is the recommended starting point for most FL projects due to its framework-agnostic design and active community. TFF is ideal for TensorFlow-based research and simulation. PySyft excels at privacy-preserving computation, and NVIDIA FLARE targets enterprise healthcare deployments.
Lilly Tech Systems