Intermediate

Edge-Cloud Synchronization

Edge devices are not islands. Models need updates, devices need monitoring, and the data collected at the edge is gold for retraining. This lesson covers the production patterns for keeping edge devices and cloud in sync — OTA model distribution, data collection pipelines, federated learning, and bandwidth management.

OTA Model Update Distribution

Pushing a new model to 10,000 devices reliably is harder than training the model. Here is a production OTA update system:

import hashlib
import json
import aiohttp
import asyncio
from pathlib import Path
from dataclasses import dataclass
from enum import Enum

class UpdateStatus(Enum):
    PENDING = "pending"
    DOWNLOADING = "downloading"
    VALIDATING = "validating"
    APPLYING = "applying"
    ACTIVE = "active"
    ROLLED_BACK = "rolled_back"
    FAILED = "failed"

@dataclass
class ModelManifest:
    """Describes a model version for OTA distribution."""
    model_id: str
    version: str
    url: str              # CDN URL for download
    sha256: str           # Integrity check
    size_bytes: int
    min_device_ram_mb: int
    min_firmware_version: str
    rollout_percentage: int  # 0-100, for staged rollouts

class OTAUpdateClient:
    """
    Runs on each edge device. Checks for updates, downloads,
    validates, and applies new models with automatic rollback.
    """
    def __init__(self, device_id: str, server_url: str, model_dir: str):
        self.device_id = device_id
        self.server_url = server_url
        self.model_dir = Path(model_dir)
        self.current_version = self._load_current_version()

    async def check_for_update(self) -> ModelManifest | None:
        """Poll server for available updates."""
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{self.server_url}/api/updates/check",
                params={
                    "device_id": self.device_id,
                    "current_version": self.current_version,
                    "device_ram_mb": self._get_ram_mb(),
                    "firmware_version": self._get_firmware_version(),
                },
            ) as resp:
                resp.raise_for_status()
                data = await resp.json()

        if data.get("update_available"):
            return ModelManifest(**data["manifest"])
        return None

    async def apply_update(self, manifest: ModelManifest) -> bool:
        """Download, validate, and swap to new model."""
        new_path = self.model_dir / f"model_v{manifest.version}.tflite"
        old_path = self.model_dir / "model_current.tflite"
        backup_path = self.model_dir / "model_backup.tflite"

        try:
            # Step 1: Download with progress tracking
            await self._report_status(UpdateStatus.DOWNLOADING)
            await self._download_with_resume(manifest.url, new_path, manifest.size_bytes)

            # Step 2: Validate integrity
            await self._report_status(UpdateStatus.VALIDATING)
            actual_hash = self._compute_sha256(new_path)
            if actual_hash != manifest.sha256:
                raise ValueError(f"Hash mismatch: {actual_hash} != {manifest.sha256}")

            # Step 3: Validate model loads and produces output
            test_output = self._test_model(new_path)
            if test_output is None:
                raise ValueError("Model failed validation inference")

            # Step 4: Atomic swap (backup old, activate new)
            await self._report_status(UpdateStatus.APPLYING)
            if old_path.exists():
                old_path.rename(backup_path)  # Keep backup for rollback
            new_path.rename(old_path)

            self.current_version = manifest.version
            await self._report_status(UpdateStatus.ACTIVE)
            return True

        except Exception as e:
            # Rollback: restore backup
            if backup_path.exists() and not old_path.exists():
                backup_path.rename(old_path)
            await self._report_status(UpdateStatus.FAILED, error=str(e))
            return False

    async def _download_with_resume(self, url: str, path: Path, total_size: int):
        """Download with resume support for unreliable connections."""
        existing_size = path.stat().st_size if path.exists() else 0

        headers = {}
        if existing_size > 0:
            headers["Range"] = f"bytes={existing_size}-"

        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers) as resp:
                resp.raise_for_status()
                # A server may ignore Range and return 200 with the full body;
                # only append when it honored the request with 206 Partial Content.
                mode = "ab" if resp.status == 206 else "wb"
                with open(path, mode) as f:
                    async for chunk in resp.content.iter_chunked(64 * 1024):
                        f.write(chunk)

    async def _report_status(self, status: UpdateStatus, error: str | None = None):
        """Report update progress to cloud."""
        async with aiohttp.ClientSession() as session:
            await session.post(
                f"{self.server_url}/api/updates/status",
                json={
                    "device_id": self.device_id,
                    "status": status.value,
                    "version": self.current_version,
                    "error": error,
                }
            )

# Edge device main loop
async def update_loop(client: OTAUpdateClient):
    while True:
        manifest = await client.check_for_update()
        if manifest:
            success = await client.apply_update(manifest)
            if success:
                print(f"Updated to model v{manifest.version}")
            else:
                print("Update failed, running previous version")
        await asyncio.sleep(300)  # Check every 5 minutes
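
The client above calls a few helpers it does not define, including `_compute_sha256`. A minimal sketch of that helper (a standalone function here, not part of the original class) hashes the file in chunks so a multi-hundred-MB model never has to fit in RAM:

```python
import hashlib
from pathlib import Path

def compute_sha256(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Hash a file in 1 MB chunks to keep memory use constant."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            h.update(chunk)
    return h.hexdigest()
```

The same pattern works for verifying partial downloads before resuming them.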

Staged Rollout Strategy

Never push a model to all devices at once. Use staged rollouts to catch issues early:

| Stage | Devices | Duration | Success Criteria | Action on Failure |
|---|---|---|---|---|
| Canary | 1% (10 devices) | 24 hours | No crashes, accuracy ≥ baseline, latency ≤ 1.2x | Halt rollout, auto-rollback canary devices |
| Early Adopter | 10% (100 devices) | 48 hours | Error rate < 0.1%, no new edge cases | Halt rollout, investigate, decide on rollback |
| Gradual | 50% (500 devices) | 48 hours | Metrics stable, no support tickets | Pause and investigate |
| Full | 100% (1,000 devices) | Immediate | Metrics stable for 7 days post-rollout | Emergency rollback via fleet command |
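
One way the server side could honor the manifest's `rollout_percentage` is to hash each device ID into a stable 0-99 bucket, so the same devices stay in each stage across repeated update checks. A sketch (the function name is ours, not from the client above):

```python
import hashlib

def in_rollout(device_id: str, rollout_percentage: int) -> bool:
    """Deterministically map a device to a 0-99 bucket; buckets below
    the rollout percentage receive the update."""
    digest = hashlib.sha256(device_id.encode()).hexdigest()
    bucket = int(digest[:8], 16) % 100
    return bucket < rollout_percentage
```

Because the bucket is stable, raising the percentage only adds devices to the rollout; no device that already updated is ever silently excluded.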

Data Collection from Edge

Edge devices generate valuable data for model retraining. The challenge is collecting it efficiently without overwhelming bandwidth:

import hashlib
import sqlite3
import json
import gzip
import time
from typing import Dict, List

class EdgeDataCollector:
    """
    Collects inference metadata on edge devices and syncs to cloud.
    NEVER sends raw images/video -- only metadata and selected samples.
    """
    def __init__(self, db_path: str = "edge_data.db", model_version: str = "unknown"):
        self.model_version = model_version  # Stamped on every logged inference
        self.db = sqlite3.connect(db_path)
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS inference_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp REAL,
                model_version TEXT,
                prediction TEXT,
                confidence REAL,
                latency_ms REAL,
                input_hash TEXT,
                synced INTEGER DEFAULT 0
            )
        """)
        self.db.execute("""
            CREATE TABLE IF NOT EXISTS hard_samples (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp REAL,
                input_data BLOB,
                prediction TEXT,
                confidence REAL,
                synced INTEGER DEFAULT 0
            )
        """)

    def log_inference(self, prediction: str, confidence: float,
                      latency_ms: float, input_data: bytes | None = None):
        """Log every inference. Store hard samples for retraining."""
        input_hash = hashlib.md5(input_data).hexdigest() if input_data else None

        self.db.execute(
            "INSERT INTO inference_log VALUES (NULL,?,?,?,?,?,?,0)",
            (time.time(), self.model_version, prediction, confidence,
             latency_ms, input_hash)
        )

        # Save hard samples: low confidence = model is uncertain
        if confidence < 0.7 and input_data:
            self.db.execute(
                "INSERT INTO hard_samples VALUES (NULL,?,?,?,?,0)",
                (time.time(), input_data, prediction, confidence)
            )

        self.db.commit()

    def get_sync_batch(self, max_size_mb: float = 1.0) -> Dict:
        """
        Prepare a batch for cloud sync.
        Prioritize: hard samples > recent logs > old logs.
        """
        batch = {"metadata": [], "hard_samples": []}
        current_size = 0
        max_bytes = max_size_mb * 1024 * 1024

        # 1. Include all unsynced metadata (small, always send)
        rows = self.db.execute(
            "SELECT * FROM inference_log WHERE synced=0 ORDER BY timestamp DESC LIMIT 10000"
        ).fetchall()

        for row in rows:
            entry = {
                "id": row[0], "timestamp": row[1], "model_version": row[2],
                "prediction": row[3], "confidence": row[4], "latency_ms": row[5]
            }
            batch["metadata"].append(entry)
            current_size += len(json.dumps(entry))

        # 2. Include hard samples up to size limit
        samples = self.db.execute(
            "SELECT * FROM hard_samples WHERE synced=0 ORDER BY confidence ASC LIMIT 100"
        ).fetchall()

        for sample in samples:
            if current_size + len(sample[2]) > max_bytes:
                break
            batch["hard_samples"].append({
                "id": sample[0], "timestamp": sample[1],
                "input_data": sample[2],  # Base64 in production
                "prediction": sample[3], "confidence": sample[4]
            })
            current_size += len(sample[2])

        return batch

    def mark_synced(self, metadata_ids: List[int], sample_ids: List[int]):
        """Mark records as synced after successful upload."""
        if metadata_ids:
            placeholders = ",".join("?" * len(metadata_ids))
            self.db.execute(
                f"UPDATE inference_log SET synced=1 WHERE id IN ({placeholders})",
                metadata_ids
            )
        if sample_ids:
            placeholders = ",".join("?" * len(sample_ids))
            self.db.execute(
                f"UPDATE hard_samples SET synced=1 WHERE id IN ({placeholders})",
                sample_ids
            )
        self.db.commit()
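
Before upload, the batch from `get_sync_batch` can be serialized and gzipped; JSON metadata typically compresses 5-10x. A minimal sketch, assuming any binary `input_data` blobs have already been base64-encoded as the code above notes:

```python
import gzip
import json

def pack_batch(batch: dict) -> bytes:
    """Serialize a sync batch and gzip it for upload."""
    raw = json.dumps(batch).encode("utf-8")
    return gzip.compress(raw, compresslevel=6)

def unpack_batch(payload: bytes) -> dict:
    """Server side: decompress and parse the uploaded batch."""
    return json.loads(gzip.decompress(payload))
```

The compression ratio is best on repetitive metadata (identical keys, similar predictions); already-compressed inputs like JPEGs gain little.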

Federated Learning Basics

Federated learning trains the model on edge devices without sending raw data to the cloud. Each device trains locally and only sends model weight updates:

import numpy as np
from typing import List, Dict

class FederatedServer:
    """
    Cloud server that orchestrates federated learning rounds.
    Devices train locally, server aggregates weight updates.
    """
    def __init__(self, global_model_weights: np.ndarray):
        self.global_weights = global_model_weights
        self.round_number = 0

    def start_round(self, num_devices: int = 10) -> Dict:
        """Select devices and send current global model."""
        self.round_number += 1
        return {
            "round": self.round_number,
            "global_weights": self.global_weights.tolist(),
            "local_epochs": 3,       # Train 3 epochs locally
            "learning_rate": 0.001,
            "min_samples": 50,       # Device needs >= 50 samples
        }

    def aggregate_updates(self, device_updates: List[Dict]) -> np.ndarray:
        """
        Federated Averaging (FedAvg): weighted average of device updates.
        Weight by number of local samples (devices with more data count more).
        """
        total_samples = sum(u["num_samples"] for u in device_updates)

        aggregated = np.zeros_like(self.global_weights)
        for update in device_updates:
            weight = update["num_samples"] / total_samples
            aggregated += weight * np.array(update["weights"])

        self.global_weights = aggregated
        return self.global_weights

class FederatedClient:
    """Runs on each edge device. Trains locally, sends updates."""

    def __init__(self, device_id: str, local_data):
        self.device_id = device_id
        self.local_data = local_data

    def train_local(self, round_config: Dict) -> Dict:
        """Train on local data and return weight updates."""
        # Load global weights into the local model. _create_model and the
        # fit/evaluate calls below are framework placeholders, not a real API.
        model = self._create_model()
        model.set_weights(round_config["global_weights"])

        # Train on LOCAL data only (data never leaves device)
        model.fit(
            self.local_data.x, self.local_data.y,
            epochs=round_config["local_epochs"],
            learning_rate=round_config["learning_rate"]
        )

        return {
            "device_id": self.device_id,
            "weights": model.get_weights(),  # Only weights sent to cloud
            "num_samples": len(self.local_data),
            "local_loss": model.evaluate(self.local_data),
        }

# Federated learning round (orchestrated by cloud)
# 1. Server sends global model to 10 selected devices
# 2. Each device trains 3 epochs on local data (~30 seconds)
# 3. Devices send weight updates to server (a few MB each)
# 4. Server averages updates (FedAvg) to create new global model
# 5. Repeat every 24 hours
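
The aggregation step is easy to verify by hand. With two devices contributing weight vectors [1, 2] (100 samples) and [3, 4] (300 samples), FedAvg uses mixing coefficients 0.25 and 0.75. A self-contained check with hypothetical numbers:

```python
import numpy as np

def fedavg(updates: list[dict]) -> np.ndarray:
    """Sample-weighted average of device weight vectors (FedAvg)."""
    total = sum(u["num_samples"] for u in updates)
    agg = np.zeros_like(np.asarray(updates[0]["weights"], dtype=float))
    for u in updates:
        agg += (u["num_samples"] / total) * np.asarray(u["weights"], dtype=float)
    return agg

updates = [
    {"weights": [1.0, 2.0], "num_samples": 100},
    {"weights": [3.0, 4.0], "num_samples": 300},
]
# 0.25 * [1, 2] + 0.75 * [3, 4] = [2.5, 3.5]
result = fedavg(updates)
```

Note the coefficients sum to 1, so the aggregate stays on the same scale as the device weights.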
💡 Apply at work: Federated learning is powerful but complex. Start with simple data collection (inference logs + hard samples) and centralized retraining. Move to federated learning only when you have privacy requirements that prevent any data from leaving devices, or when the edge data is too large to upload (e.g., hours of video per device per day).

Sync Strategies

How often and when to sync depends on your connectivity profile and data urgency:

| Strategy | When to Use | Bandwidth | Data Freshness |
|---|---|---|---|
| Periodic (every N minutes) | Stable connectivity, predictable data | Predictable, schedulable | N minutes stale |
| Event-driven (on trigger) | Critical detections, anomalies | Bursty, unpredictable | Near real-time for events |
| Opportunistic (when WiFi available) | Mobile devices, field equipment | Efficient, uses cheap bandwidth | Hours to days |
| Batched (nightly/weekly) | Low-priority analytics, model data | Minimal, compressed | 24 hours+ |

class SyncScheduler:
    """Manages when and how to sync edge data to cloud."""

    def __init__(self, strategy: str = "adaptive"):
        self.strategy = strategy
        self.pending_events = []
        self.last_sync = time.time()

    async def should_sync(self) -> bool:
        if self.strategy == "periodic":
            return time.time() - self.last_sync > 300  # Every 5 min

        elif self.strategy == "event_driven":
            return len(self.pending_events) > 0

        elif self.strategy == "opportunistic":
            return self._is_on_wifi() and self._battery_above(20)

        elif self.strategy == "adaptive":
            # Smart: sync critical events immediately, batch the rest
            has_critical = any(e.priority == "critical" for e in self.pending_events)
            if has_critical:
                return True
            if self._is_on_wifi() and len(self.pending_events) > 100:
                return True
            return time.time() - self.last_sync > 3600  # At least hourly

        return False  # Unknown strategy: default to not syncing

    def _is_on_wifi(self) -> bool:
        """Check if connected to WiFi (not cellular)."""
        # Platform-specific implementation
        return check_network_type() == "wifi"

    def _battery_above(self, percent: int) -> bool:
        """Don't sync when battery is low."""
        return get_battery_level() > percent
📝 Production reality: Bandwidth costs dominate edge AI operations at scale. A 1,000-device fleet syncing 1 MB every 5 minutes generates 288 GB/day of upstream traffic. Use compression (gzip reduces JSON 5-10x), send metadata only (not raw inputs), and sync hard samples at lower priority. Your bandwidth budget should be part of the hardware BOM.

Key Takeaways

  • OTA model updates require download-with-resume, SHA256 integrity checks, validation inference, atomic swap, and automatic rollback. Never skip any step.
  • Use staged rollouts: 1% canary (24h), 10% early adopter (48h), 50% gradual (48h), then 100%. Monitor accuracy, latency, and crash rates at each stage.
  • Collect inference metadata from every device (small, always sync), and save hard samples (low confidence predictions) for retraining.
  • Federated learning keeps data on-device but adds complexity. Start with centralized data collection unless privacy requirements mandate on-device training.
  • Use adaptive sync: critical events sync immediately, bulk data syncs on WiFi when battery is sufficient. Budget bandwidth as a line item in your edge hardware BOM.

What Is Next

In the next lesson, we will cover offline and intermittent connectivity — how to build edge AI systems that work reliably with no internet, handle network drops gracefully, and recover without data loss.