Edge-Cloud Synchronization
Edge devices are not islands. Models need updates, devices need monitoring, and the data collected at the edge is gold for retraining. This lesson covers the production patterns for keeping edge devices and cloud in sync — OTA model distribution, data collection pipelines, federated learning, and bandwidth management.
OTA Model Update Distribution
Pushing a new model to 10,000 devices reliably is harder than training the model. Here is a production OTA update system:
import hashlib
import json
import aiohttp
import asyncio
from pathlib import Path
from dataclasses import dataclass
from enum import Enum
class UpdateStatus(Enum):
    """Lifecycle states of an OTA model update, reported to the cloud."""
    PENDING = "pending"          # update known, nothing started yet
    DOWNLOADING = "downloading"  # fetching model file from CDN
    VALIDATING = "validating"    # integrity hash + validation inference
    APPLYING = "applying"        # swapping the new model into place
    ACTIVE = "active"            # new model is live on the device
    ROLLED_BACK = "rolled_back"  # update failed; backup model restored
    FAILED = "failed"            # update aborted; device keeps prior model
@dataclass
class ModelManifest:
    """Describes a model version for OTA distribution.

    Produced by the update server and consumed by OTAUpdateClient to
    decide whether, and from where, to download a new model.
    """
    model_id: str
    version: str
    url: str  # CDN URL for download
    sha256: str  # Integrity check (hex digest of the file at `url`)
    size_bytes: int  # expected file size; used for download progress/resume
    min_device_ram_mb: int  # device eligibility gate (RAM floor)
    min_firmware_version: str  # device eligibility gate (firmware floor)
    rollout_percentage: int  # 0-100, for staged rollouts
class OTAUpdateClient:
"""
Runs on each edge device. Checks for updates, downloads,
validates, and applies new models with automatic rollback.
"""
def __init__(self, device_id: str, server_url: str, model_dir: str):
self.device_id = device_id
self.server_url = server_url
self.model_dir = Path(model_dir)
self.current_version = self._load_current_version()
async def check_for_update(self) -> ModelManifest | None:
"""Poll server for available updates."""
async with aiohttp.ClientSession() as session:
resp = await session.get(
f"{self.server_url}/api/updates/check",
params={
"device_id": self.device_id,
"current_version": self.current_version,
"device_ram_mb": self._get_ram_mb(),
"firmware_version": self._get_firmware_version(),
}
)
data = await resp.json()
if data.get("update_available"):
return ModelManifest(**data["manifest"])
return None
async def apply_update(self, manifest: ModelManifest) -> bool:
"""Download, validate, and swap to new model."""
new_path = self.model_dir / f"model_v{manifest.version}.tflite"
old_path = self.model_dir / "model_current.tflite"
backup_path = self.model_dir / "model_backup.tflite"
try:
# Step 1: Download with progress tracking
await self._report_status(UpdateStatus.DOWNLOADING)
await self._download_with_resume(manifest.url, new_path, manifest.size_bytes)
# Step 2: Validate integrity
await self._report_status(UpdateStatus.VALIDATING)
actual_hash = self._compute_sha256(new_path)
if actual_hash != manifest.sha256:
raise ValueError(f"Hash mismatch: {actual_hash} != {manifest.sha256}")
# Step 3: Validate model loads and produces output
test_output = self._test_model(new_path)
if test_output is None:
raise ValueError("Model failed validation inference")
# Step 4: Atomic swap (backup old, activate new)
await self._report_status(UpdateStatus.APPLYING)
if old_path.exists():
old_path.rename(backup_path) # Keep backup for rollback
new_path.rename(old_path)
self.current_version = manifest.version
await self._report_status(UpdateStatus.ACTIVE)
return True
except Exception as e:
# Rollback: restore backup
if backup_path.exists() and not old_path.exists():
backup_path.rename(old_path)
await self._report_status(UpdateStatus.FAILED, error=str(e))
return False
async def _download_with_resume(self, url: str, path: Path, total_size: int):
"""Download with resume support for unreliable connections."""
existing_size = path.stat().st_size if path.exists() else 0
headers = {}
if existing_size > 0:
headers["Range"] = f"bytes={existing_size}-"
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as resp:
mode = "ab" if existing_size > 0 else "wb"
with open(path, mode) as f:
async for chunk in resp.content.iter_chunked(64 * 1024):
f.write(chunk)
async def _report_status(self, status: UpdateStatus, error: str = None):
"""Report update progress to cloud."""
async with aiohttp.ClientSession() as session:
await session.post(
f"{self.server_url}/api/updates/status",
json={
"device_id": self.device_id,
"status": status.value,
"version": self.current_version,
"error": error,
}
)
# Edge device main loop
async def update_loop(client: OTAUpdateClient):
    """Poll the update server forever, applying any update that arrives."""
    poll_interval_s = 300  # check every 5 minutes
    while True:
        manifest = await client.check_for_update()
        if manifest is not None:
            if await client.apply_update(manifest):
                print(f"Updated to model v{manifest.version}")
            else:
                print("Update failed, running previous version")
        await asyncio.sleep(poll_interval_s)
Staged Rollout Strategy
Never push a model to all devices at once. Use staged rollouts to catch issues early — the device counts below assume an example fleet of 1,000 devices; scale the percentages to your own fleet:
| Stage | Devices | Duration | Success Criteria | Action on Failure |
|---|---|---|---|---|
| Canary | 1% (10 devices) | 24 hours | No crashes, accuracy ≥ baseline, latency ≤ 1.2x | Halt rollout, auto-rollback canary devices |
| Early Adopter | 10% (100 devices) | 48 hours | Error rate < 0.1%, no new edge cases | Halt rollout, investigate, decide rollback |
| Gradual | 50% (500 devices) | 48 hours | Metrics stable, no support tickets | Pause and investigate |
| Full | 100% (1,000 devices) | Immediate | Metrics stable for 7 days post-rollout | Emergency rollback via fleet command |
Data Collection from Edge
Edge devices generate valuable data for model retraining. The challenge is collecting it efficiently without overwhelming bandwidth:
import sqlite3
import json
import gzip
import time
from typing import List, Dict
class EdgeDataCollector:
"""
Collects inference metadata on edge devices and syncs to cloud.
NEVER sends raw images/video -- only metadata and selected samples.
"""
def __init__(self, db_path: str = "edge_data.db"):
self.db = sqlite3.connect(db_path)
self.db.execute("""
CREATE TABLE IF NOT EXISTS inference_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp REAL,
model_version TEXT,
prediction TEXT,
confidence REAL,
latency_ms REAL,
input_hash TEXT,
synced INTEGER DEFAULT 0
)
""")
self.db.execute("""
CREATE TABLE IF NOT EXISTS hard_samples (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp REAL,
input_data BLOB,
prediction TEXT,
confidence REAL,
synced INTEGER DEFAULT 0
)
""")
def log_inference(self, prediction: str, confidence: float,
latency_ms: float, input_data: bytes = None):
"""Log every inference. Store hard samples for retraining."""
input_hash = hashlib.md5(input_data).hexdigest() if input_data else None
self.db.execute(
"INSERT INTO inference_log VALUES (NULL,?,?,?,?,?,?,0)",
(time.time(), self.model_version, prediction, confidence,
latency_ms, input_hash)
)
# Save hard samples: low confidence = model is uncertain
if confidence < 0.7 and input_data:
self.db.execute(
"INSERT INTO hard_samples VALUES (NULL,?,?,?,?,0)",
(time.time(), input_data, prediction, confidence)
)
self.db.commit()
def get_sync_batch(self, max_size_mb: float = 1.0) -> Dict:
"""
Prepare a batch for cloud sync.
Prioritize: hard samples > recent logs > old logs.
"""
batch = {"metadata": [], "hard_samples": []}
current_size = 0
max_bytes = max_size_mb * 1024 * 1024
# 1. Include all unsynced metadata (small, always send)
rows = self.db.execute(
"SELECT * FROM inference_log WHERE synced=0 ORDER BY timestamp DESC LIMIT 10000"
).fetchall()
for row in rows:
entry = {
"id": row[0], "timestamp": row[1], "model_version": row[2],
"prediction": row[3], "confidence": row[4], "latency_ms": row[5]
}
batch["metadata"].append(entry)
current_size += len(json.dumps(entry))
# 2. Include hard samples up to size limit
samples = self.db.execute(
"SELECT * FROM hard_samples WHERE synced=0 ORDER BY confidence ASC LIMIT 100"
).fetchall()
for sample in samples:
if current_size + len(sample[2]) > max_bytes:
break
batch["hard_samples"].append({
"id": sample[0], "timestamp": sample[1],
"input_data": sample[2], # Base64 in production
"prediction": sample[3], "confidence": sample[4]
})
current_size += len(sample[2])
return batch
def mark_synced(self, metadata_ids: List[int], sample_ids: List[int]):
"""Mark records as synced after successful upload."""
if metadata_ids:
placeholders = ",".join("?" * len(metadata_ids))
self.db.execute(
f"UPDATE inference_log SET synced=1 WHERE id IN ({placeholders})",
metadata_ids
)
if sample_ids:
placeholders = ",".join("?" * len(sample_ids))
self.db.execute(
f"UPDATE hard_samples SET synced=1 WHERE id IN ({placeholders})",
sample_ids
)
self.db.commit()
Federated Learning Basics
Federated learning trains the model on edge devices without sending raw data to the cloud. Each device trains locally and only sends model weight updates:
import numpy as np
from typing import List, Dict
class FederatedServer:
    """
    Cloud server that orchestrates federated learning rounds.
    Devices train locally; the server aggregates weight updates.
    """

    def __init__(self, global_model_weights: np.ndarray):
        self.global_weights = global_model_weights
        self.round_number = 0  # incremented at the start of each round

    def start_round(self, num_devices: int = 10) -> Dict:
        """Begin a new round and return the config broadcast to devices.

        Note: actual device selection happens elsewhere; `num_devices`
        documents the intended cohort size.
        """
        self.round_number += 1
        return {
            "round": self.round_number,
            "global_weights": self.global_weights.tolist(),
            "local_epochs": 3,  # Train 3 epochs locally
            "learning_rate": 0.001,
            "min_samples": 50,  # Device needs >= 50 samples
        }

    def aggregate_updates(self, device_updates: List[Dict]) -> np.ndarray:
        """
        Federated Averaging (FedAvg): weighted average of device updates,
        weighted by local sample count (more data => more influence).

        Raises:
            ValueError: if no updates are given, or all updates report
                zero samples (the original crashed with ZeroDivisionError).
        """
        if not device_updates:
            raise ValueError("No device updates to aggregate")
        total_samples = sum(u["num_samples"] for u in device_updates)
        if total_samples <= 0:
            raise ValueError("Device updates report zero total samples")
        aggregated = np.zeros_like(self.global_weights)
        for update in device_updates:
            fraction = update["num_samples"] / total_samples
            # `aggregated + ...` (not +=) lets numpy promote the dtype if
            # the global weights were stored as integers.
            aggregated = aggregated + fraction * np.asarray(update["weights"])
        self.global_weights = aggregated
        return self.global_weights
class FederatedClient:
    """Device-side participant in federated learning.

    Trains on this device's local data only and ships weight updates --
    never raw data -- back to the server.
    """

    def __init__(self, device_id: str, local_data):
        self.device_id = device_id
        self.local_data = local_data

    def train_local(self, round_config: Dict) -> Dict:
        """Fine-tune the global model locally; return weights + metadata.

        The returned `num_samples` is what the server uses to weight
        this device's contribution in FedAvg.
        """
        # Start from the server's current global weights.
        local_model = self._create_model()
        local_model.set_weights(round_config["global_weights"])

        # Train on LOCAL data only -- raw data never leaves the device.
        local_model.fit(
            self.local_data.x,
            self.local_data.y,
            epochs=round_config["local_epochs"],
            learning_rate=round_config["learning_rate"],
        )

        return {
            "device_id": self.device_id,
            "weights": local_model.get_weights(),  # Only weights go to cloud
            "num_samples": len(self.local_data),
            "local_loss": local_model.evaluate(self.local_data),
        }
# Federated learning round (orchestrated by cloud)
# 1. Server sends global model to 10 selected devices
# 2. Each device trains 3 epochs on local data (~30 seconds)
# 3. Devices send weight updates to server (a few MB each)
# 4. Server averages updates (FedAvg) to create new global model
# 5. Repeat every 24 hours
Sync Strategies
How often and when to sync depends on your connectivity profile and data urgency:
| Strategy | When to Use | Bandwidth | Data Freshness |
|---|---|---|---|
| Periodic (every N minutes) | Stable connectivity, predictable data | Predictable, schedulable | N minutes stale |
| Event-driven (on trigger) | Critical detections, anomalies | Bursty, unpredictable | Near real-time for events |
| Opportunistic (when WiFi available) | Mobile devices, field equipment | Efficient, uses cheap bandwidth | Hours to days |
| Batched (nightly/weekly) | Low-priority analytics, model data | Minimal, compressed | 24 hours+ |
class SyncScheduler:
    """Decides when edge data should be pushed to the cloud.

    NOTE: `check_network_type` and `get_battery_level` are
    platform-specific functions provided elsewhere.
    """

    def __init__(self, strategy: str = "adaptive"):
        self.strategy = strategy
        self.pending_events = []
        self.last_sync = time.time()

    async def should_sync(self) -> bool:
        """Return True when a sync should happen under the configured strategy."""
        elapsed = time.time() - self.last_sync

        if self.strategy == "periodic":
            # Fixed cadence: once every five minutes.
            return elapsed > 300

        if self.strategy == "event_driven":
            # Sync as soon as anything is queued.
            return len(self.pending_events) > 0

        if self.strategy == "opportunistic":
            # Only on cheap bandwidth, and only with enough battery.
            return self._is_on_wifi() and self._battery_above(20)

        if self.strategy == "adaptive":
            # Critical events go out immediately; bulk data waits for
            # WiFi + a big enough backlog, or an hourly deadline.
            if any(e.priority == "critical" for e in self.pending_events):
                return True
            if self._is_on_wifi() and len(self.pending_events) > 100:
                return True
            return elapsed > 3600

    def _is_on_wifi(self) -> bool:
        """True when connected via WiFi rather than cellular."""
        # Platform-specific implementation
        return check_network_type() == "wifi"

    def _battery_above(self, percent: int) -> bool:
        """True when battery level exceeds `percent` (don't sync when low)."""
        return get_battery_level() > percent
Key Takeaways
- OTA model updates require download-with-resume, SHA256 integrity checks, validation inference, atomic swap, and automatic rollback. Never skip any step.
- Use staged rollouts: 1% canary (24h), 10% early adopter (48h), 50% gradual (48h), then 100%. Monitor accuracy, latency, and crash rates at each stage.
- Collect inference metadata from every device (small, always sync), and save hard samples (low confidence predictions) for retraining.
- Federated learning keeps data on-device but adds complexity. Start with centralized data collection unless privacy requirements mandate on-device training.
- Use adaptive sync: critical events sync immediately, bulk data syncs on WiFi when battery is sufficient. Budget bandwidth as a line item in your edge hardware BOM.
What Is Next
In the next lesson, we will cover offline and intermittent connectivity — how to build edge AI systems that work reliably with no internet, handle network drops gracefully, and recover without data loss.
Lilly Tech Systems