Intermediate

Model Optimization for Edge

A ResNet-50 is 97MB. A Jetson Nano has 4GB of RAM, shared between CPU and GPU. A Coral Edge TPU only runs INT8 models. Edge deployment means making your model smaller, faster, and more efficient without destroying accuracy. This lesson covers the four core optimization techniques, with production-style code you can adapt to your own models and data.

Optimization Techniques Overview

There are four main techniques for making models edge-ready. They can be combined for maximum compression:

Technique | Size Reduction | Speed Improvement | Accuracy Impact | Effort
Post-Training Quantization (PTQ) | 4x (FP32 → INT8) | 2-4x | 0.5-2% drop | Low (10 min)
Quantization-Aware Training (QAT) | 4x (FP32 → INT8) | 2-4x | 0.1-0.5% drop | Medium (retrain)
Pruning (Structured) | 2-5x | 2-3x | 0.5-3% drop | Medium
Knowledge Distillation | 10-50x | 5-20x | 1-5% drop | High (train student)
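
The 4x figure for quantization is plain arithmetic: stored size is roughly parameter count times bytes per parameter. A minimal sketch, assuming about 25.6M parameters for ResNet-50 (the count Keras reports for its ResNet50 with the classifier head) and ignoring file-format overhead:

```python
def model_size_mib(num_params: int, bits_per_param: int) -> float:
    """Approximate weight storage in MiB (ignores metadata and file overhead)."""
    return num_params * bits_per_param / 8 / 2**20

# Assumed parameter count: Keras ResNet50 including the classifier head
RESNET50_PARAMS = 25_636_712

fp32 = model_size_mib(RESNET50_PARAMS, 32)
int8 = model_size_mib(RESNET50_PARAMS, 8)
print(f"FP32: {fp32:.1f} MiB, INT8: {int8:.1f} MiB, ratio: {fp32 / int8:.0f}x")
```

This prints about 97.8 MiB for FP32 and 24.4 MiB for INT8, close to the sizes in the PTQ example; real model files add a little metadata on top.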

Post-Training Quantization (PTQ)

PTQ converts a trained FP32 model to INT8 without any retraining; it only needs a small calibration set to estimate the value ranges of activations. It is the fastest path from a cloud model to edge deployment:

import os
import numpy as np
import tensorflow as tf

# Step 1: Load your trained model
model = tf.keras.models.load_model("resnet50_trained.h5")
original_size_mb = os.path.getsize("resnet50_trained.h5") / 1e6
print(f"Original model size: {original_size_mb:.1f} MB")
# -> Original model size: 97.8 MB

# Step 2: Convert to TFLite with INT8 quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Provide a representative dataset for calibration (100-500 samples).
# Assumes `training_data` is a NumPy array of preprocessed images shaped
# (N, 224, 224, 3), prepared exactly as during training.
def representative_dataset():
    """Feed real data so the converter can determine value ranges."""
    for i in range(200):
        sample = training_data[i:i+1].astype(np.float32)
        yield [sample]

converter.representative_dataset = representative_dataset

# Force all ops to INT8 (required for the Coral Edge TPU and other
# integer-only accelerators); conversion fails if an op has no INT8 kernel
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS_INT8
]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

# Convert
tflite_model = converter.convert()

# Save
with open("resnet50_int8.tflite", "wb") as f:
    f.write(tflite_model)

quantized_size_mb = len(tflite_model) / 1e6
print(f"Quantized model size: {quantized_size_mb:.1f} MB")
# -> Quantized model size: 24.5 MB (4x smaller)
print(f"Compression ratio: {original_size_mb / quantized_size_mb:.1f}x")

# For Coral, compile the INT8 model for the Edge TPU as a separate step:
#   edgetpu_compiler resnet50_int8.tflite
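
What the representative dataset is for: the converter records the range of each tensor and maps it onto 8 bits with a scale and a zero point. The NumPy sketch below shows that affine mapping for one tensor using plain min/max calibration. It illustrates the idea only and is not TFLite's exact algorithm (TFLite, for example, quantizes weights symmetrically per channel).

```python
import numpy as np

def calibrate_uint8(samples: np.ndarray) -> tuple[float, int]:
    """Derive an asymmetric uint8 scale and zero point from observed min/max."""
    lo = min(float(samples.min()), 0.0)  # the range must include 0.0
    hi = max(float(samples.max()), 0.0)
    scale = (hi - lo) / 255.0 if hi > lo else 1.0
    zero_point = int(round(-lo / scale))  # uint8 code that represents 0.0
    return scale, zero_point

def quantize(x, scale, zero_point):
    return np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)

def dequantize(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale

# Synthetic "activations" standing in for values seen during calibration
rng = np.random.default_rng(0)
activations = rng.normal(loc=1.0, scale=2.0, size=10_000).astype(np.float32)

scale, zp = calibrate_uint8(activations)
q = quantize(activations, scale, zp)
max_err = np.abs(dequantize(q, scale, zp) - activations).max()
print(f"scale={scale:.4f}, zero_point={zp}, max error={max_err:.4f}")
```

Values inside the calibrated range come back within about half a quantization step; values outside it are clipped, which is why the calibration samples must resemble real inputs.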

Quantization-Aware Training (QAT)

QAT simulates quantization during training, so the model learns to be accurate with reduced precision. Use this when PTQ accuracy is not good enough:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Step 1: Start with your trained FP32 model
base_model = tf.keras.models.load_model("resnet50_trained.h5")

# Step 2: Apply QAT - inserts fake quantization nodes
quantize_model = tfmot.quantization.keras.quantize_model
qat_model = quantize_model(base_model)

# Step 3: Fine-tune with quantization simulation (10-20% of original epochs)
qat_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # Lower LR
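    loss="categorical_crossentropy",  # Expects one-hot labels
    metrics=["accuracy"]
)

# Assumes train_data / val_data are batched tf.data.Dataset objects
# (e.g. batch size 32); Keras rejects batch_size= when given a Dataset
qat_model.fit(
    train_data,
    epochs=5,            # 10-20% of original training
    validation_data=val_data
)

# Step 4: Convert to TFLite INT8 using the ranges learned during QAT.
# Input/output tensors stay float32 unless you also set
# inference_input_type / inference_output_type as in the PTQ example.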
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open("resnet50_qat_int8.tflite", "wb") as f:
    f.write(tflite_model)

# Benchmark comparison:
# PTQ INT8:  Top-1 accuracy 74.8% (was 76.1% FP32) - 1.3% drop
# QAT INT8:  Top-1 accuracy 75.7% (was 76.1% FP32) - 0.4% drop
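
The fake quantization nodes inserted in Step 2 round values to the INT8 grid in the forward pass while letting gradients through in the backward pass (the straight-through estimator). A simplified NumPy sketch of that idea, assuming a fixed [-6, 6] range for illustration; tfmot tracks the ranges itself during training:

```python
import numpy as np

def fake_quant(x, min_val=-6.0, max_val=6.0, num_bits=8):
    """Quantize then dequantize: the output stays float but lies on the INT8 grid."""
    levels = 2**num_bits - 1
    scale = (max_val - min_val) / levels
    clipped = np.clip(x, min_val, max_val)
    return np.round((clipped - min_val) / scale) * scale + min_val

def fake_quant_grad(x, upstream_grad, min_val=-6.0, max_val=6.0):
    """Straight-through estimator: pass gradients inside the range, zero outside."""
    inside = (x >= min_val) & (x <= max_val)
    return upstream_grad * inside

x = np.array([-7.0, -0.01, 0.5, 3.14159, 6.5])
print(fake_quant(x))                        # rounded, clipped forward values
print(fake_quant_grad(x, np.ones_like(x)))  # gradient mask seen by the optimizer
```

Because the loss is computed on rounded values, the weights drift toward settings that stay accurate after real INT8 conversion, which is where QAT's accuracy advantage over PTQ comes from.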

Structured Pruning

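Pruning removes weights or entire filters that contribute least to the output. Unstructured pruning zeroes individual weights, which helps compression but rarely speeds up dense hardware. Structured pruning removes whole filters, so once those channels are physically deleted the model is actually faster, not just smaller: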

import torch
import torch.nn.utils.prune as prune

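# Step 1: Load trained PyTorch model (saved with torch.save(model, ...)).
# weights_only=False is needed to unpickle a full model on PyTorch >= 2.6;
# only use it on files you trust.
model = torch.load("resnet50_trained.pt", weights_only=False)

def apply_structured_pruning(model, amount=0.3):
    """
    Zero out `amount` (e.g. 30%) of the filters in each conv layer,
    choosing the filters with the smallest L1 norm (they contribute least).
    The masks stay attached, so pruned filters remain zero while fine-tuning.
    """
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.ln_structured(
                module,
                name="weight",
                amount=amount,  # Fraction of filters to prune
                n=1,            # L1 norm
                dim=0           # Prune along output channel dimension
            )
    return model

def make_pruning_permanent(model):
    """Fold each mask into its weight tensor and drop the mask."""
    for module in model.modules():
        if hasattr(module, "weight_mask"):
            prune.remove(module, "weight")
    return model

# Step 2: Prune
pruned_model = apply_structured_pruning(model, amount=0.3)

# Step 3: Fine-tune to recover accuracy (critical step).
# Assumes `train_loader` is a DataLoader yielding (images, labels).
pruned_model.train()
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=1e-4, momentum=0.9)
for epoch in range(10):
    for images, labels in train_loader:
        outputs = pruned_model(images)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Step 4: Make pruning permanent only after fine-tuning; removing the masks
# earlier would let the optimizer regrow the zeroed filters
pruned_model = make_pruning_permanent(pruned_model)
pruned_model.eval()

# Step 5: Export to ONNX for edge deployment.
# Note: torch.nn.utils.prune zeroes filters but keeps tensor shapes, so this
# file is still full size. The results below assume the zeroed channels were
# physically removed (e.g. with a dependency-aware tool such as Torch-Pruning).
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    pruned_model, dummy_input, "resnet50_pruned.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
)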

# Results:
# Original:  97MB, 4.1 GFLOPs, 76.1% accuracy
# 30% pruned: 68MB, 2.9 GFLOPs, 75.4% accuracy (after fine-tuning)
# 50% pruned: 49MB, 2.1 GFLOPs, 74.2% accuracy (after fine-tuning)
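
Because `torch.nn.utils.prune` only zeroes filters, the size and FLOP reductions in the results above come from physically deleting channels. A NumPy sketch of that step for two consecutive conv layers, with weights stored as (out_channels, in_channels, kH, kW) like PyTorch: rank layer 1's filters by L1 norm, keep the strongest, and drop the matching input channels of layer 2. The shapes, random weights, and the `prune_filter_pair` helper are illustrative assumptions; a real ResNet also needs its BatchNorm layers and residual connections kept consistent, which is what dependency-aware pruning tools automate.

```python
import numpy as np

def prune_filter_pair(w1, b1, w2, amount=0.3):
    """L1-norm filter pruning across two consecutive conv layers.

    w1: (C_out1, C_in1, k, k) weights of layer 1, b1: (C_out1,) its bias
    w2: (C_out2, C_out1, k, k) weights of layer 2, which consumes layer 1's output
    """
    l1_norms = np.abs(w1).sum(axis=(1, 2, 3))       # one score per output filter
    n_keep = w1.shape[0] - int(round(amount * w1.shape[0]))
    keep = np.sort(np.argsort(l1_norms)[-n_keep:])  # strongest filters, original order
    return w1[keep], b1[keep], w2[:, keep], keep

def conv_macs(w, out_h, out_w):
    """Multiply-accumulates for a conv producing an out_h x out_w feature map."""
    c_out, c_in, kh, kw = w.shape
    return c_out * c_in * kh * kw * out_h * out_w

rng = np.random.default_rng(0)
w1, b1 = rng.normal(size=(64, 32, 3, 3)), rng.normal(size=64)
w2 = rng.normal(size=(128, 64, 3, 3))

p1, pb1, p2, keep = prune_filter_pair(w1, b1, w2, amount=0.3)
before = conv_macs(w1, 56, 56) + conv_macs(w2, 56, 56)
after = conv_macs(p1, 56, 56) + conv_macs(p2, 56, 56)
print(p1.shape, p2.shape, f"MACs reduced by {1 - after / before:.0%}")
```

Here both layers shrink in proportion to the 45 surviving filters (a 30% cut). Layers whose inputs and outputs are both pruned shrink faster, by roughly 1 - 0.7² ≈ 51%.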

Knowledge Distillation

Distillation trains a small "student" model to mimic a large "teacher" model. The student learns the teacher's soft probability distributions, which contain more information than hard labels: they show how similar the teacher considers each wrong class to the right one.

import torch
import torch.nn.functional as F

class DistillationTrainer:
    """
    Train a small student model to mimic a large teacher model.
    The student learns from both the true labels AND the teacher's
    soft predictions (which encode inter-class relationships).
    """
    def __init__(self, teacher, student, temperature=4.0, alpha=0.7):
        self.teacher = teacher.eval()  # Frozen, no gradient
        self.student = student
        self.temperature = temperature  # Higher = softer distributions
        self.alpha = alpha              # Weight of distillation vs hard loss

    def distillation_loss(self, student_logits, teacher_logits, true_labels):
        # Soft loss: KL divergence between teacher and student soft outputs
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_loss = F.kl_div(
            soft_student, soft_teacher, reduction="batchmean"
        ) * (self.temperature ** 2)

        # Hard loss: standard cross-entropy with true labels
        hard_loss = F.cross_entropy(student_logits, true_labels)

        # Combined loss
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def train_epoch(self, dataloader, optimizer):
        self.student.train()
        total_loss = 0

        for images, labels in dataloader:
            # Teacher prediction (no gradient needed)
            with torch.no_grad():
                teacher_logits = self.teacher(images)

            # Student prediction
            student_logits = self.student(images)

            # Combined loss
            loss = self.distillation_loss(
                student_logits, teacher_logits, labels
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        return total_loss / len(dataloader)

# Usage: Distill ResNet-50 (25M params) into MobileNet-v3 Small (2.5M params).
# Assumes train_loader / val_loader are DataLoaders yielding (images, labels)
# and evaluate() is your own (hypothetical) helper returning top-1 accuracy.
import torchvision

teacher = torch.load("resnet50_trained.pt", weights_only=False)  # 97MB, 76.1% acc
student = torchvision.models.mobilenet_v3_small(num_classes=1000)  # ~10MB FP32

trainer = DistillationTrainer(teacher, student, temperature=4.0, alpha=0.7)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

for epoch in range(100):
    loss = trainer.train_epoch(train_loader, optimizer)
    val_acc = evaluate(student, val_loader)
    print(f"Epoch {epoch}: loss={loss:.4f}, val_acc={val_acc:.2%}")

# Results:
# MobileNet-v3 trained from scratch:    67.4% accuracy, 10MB
# MobileNet-v3 distilled from ResNet:   72.8% accuracy, 10MB (+5.4%!)
# Teacher ResNet-50:                     76.1% accuracy, 97MB
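
To see what the temperature in `distillation_loss` does, here is a NumPy sketch with made-up teacher logits (illustrative values, not from a real model). Dividing by T before the softmax spreads probability onto the runner-up classes, which is the extra signal the student learns from. The KL term mirrors the PyTorch code above, including the T² factor that keeps gradient magnitudes comparable across temperatures.

```python
import numpy as np

def softmax(logits, temperature=1.0):
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def soft_kl_loss(student_logits, teacher_logits, temperature=4.0):
    """KL(teacher || student) on softened distributions, batch mean, scaled by T^2."""
    p_teacher = softmax(teacher_logits, temperature)
    log_p_student = np.log(softmax(student_logits, temperature))
    kl = (p_teacher * (np.log(p_teacher) - log_p_student)).sum(axis=-1)
    return kl.mean() * temperature**2

# Hypothetical teacher logits for classes [cat, dog, truck]
teacher = np.array([[6.0, 4.0, -2.0]])
print("T=1:", softmax(teacher, 1.0).round(3))
print("T=4:", softmax(teacher, 4.0).round(3))
print("loss vs. itself:", soft_kl_loss(teacher, teacher))
```

At T=1 the "dog" class gets about 12% of the probability; at T=4 it gets about 35%, so the student sees clearly that dog is the closest wrong answer and truck is not.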
💡
Apply at work: The optimal pipeline for most edge deployments: (1) Train a large model in the cloud, (2) Distill into MobileNet-v3 or EfficientNet-Lite, (3) Apply QAT on the student, (4) Convert to TFLite INT8. This gives you 10-40x compression with only 2-3% accuracy loss.

Edge-Optimized Architectures

Instead of compressing large models, start with architectures designed for edge from the ground up:

Architecture | Params | Size (INT8) | Latency (Coral unless noted) | Accuracy (ImageNet Top-1 unless noted) | Best For
MobileNet-v3 Small | 2.5M | 2.4MB | 3.5ms | 67.4% | Ultra-low latency, microcontrollers
MobileNet-v3 Large | 5.4M | 5.2MB | 6.2ms | 75.2% | Balanced accuracy/speed
EfficientNet-Lite0 | 4.7M | 4.4MB | 7.8ms | 75.1% | Quantization-friendly (no squeeze-and-excite, ReLU6)
EfficientNet-Lite4 | 13M | 12.5MB | 52ms | 80.2% | High accuracy edge (Jetson)
YOLO-NAS-S | 12.2M | 12MB | 15ms (Jetson) | 47.5 mAP (COCO) | Real-time object detection
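
Much of the efficiency of the MobileNet and EfficientNet-Lite families comes from depthwise separable convolutions inside their inverted-residual blocks. Each one replaces a dense k×k convolution with a per-channel k×k convolution followed by a 1×1 convolution. Quick arithmetic for an illustrative 3×3 layer with 256 input and 256 output channels (biases ignored):

```python
def standard_conv_params(c_in: int, c_out: int, k: int) -> int:
    """Dense k x k convolution: every output channel sees every input channel."""
    return k * k * c_in * c_out

def depthwise_separable_params(c_in: int, c_out: int, k: int) -> int:
    """k x k depthwise conv (one filter per input channel) + 1x1 pointwise conv."""
    depthwise = k * k * c_in
    pointwise = c_in * c_out
    return depthwise + pointwise

std = standard_conv_params(256, 256, 3)
sep = depthwise_separable_params(256, 256, 3)
print(f"standard: {std:,}  separable: {sep:,}  ratio: {std / sep:.1f}x")
```

That is about 8.7x fewer parameters, and at the same output resolution the multiply-accumulates per output pixel shrink by the same factor.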

ONNX Conversion Pipeline

ONNX is the most widely supported exchange format for edge models: export once, then load or convert the same file for many runtimes, subject to the op-support caveats below:

import time

import numpy as np
import onnx
import onnxruntime as ort
import torch

# Step 1: Export PyTorch model to ONNX
# load_trained_model() is a placeholder for however you load your model
model = load_trained_model()
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    },
    opset_version=17,  # Choose an opset your target runtime supports
    do_constant_folding=True  # Pre-compute constant expressions
)

# Step 2: Validate the ONNX model
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print(f"ONNX model validated: {len(onnx_model.graph.node)} ops")

# Step 3: Optimize with ONNX Runtime
# This fuses operations and optimizes the graph for inference;
# the optimized graph is also written to optimized_model_filepath
sess_options = ort.SessionOptions()
sess_options.optimized_model_filepath = "model_optimized.onnx"
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# Use every execution provider this ORT build supports, in priority order
session = ort.InferenceSession(
    "model.onnx", sess_options, providers=ort.get_available_providers()
)

# Step 4: Benchmark
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# Warm up so one-time initialization does not skew the timings
for _ in range(10):
    session.run(None, {"input": input_data})

times = []
for _ in range(100):
    start = time.perf_counter()
    result = session.run(None, {"input": input_data})
    times.append(time.perf_counter() - start)

print(f"Avg inference: {np.mean(times)*1000:.1f}ms")
print(f"P95 inference: {np.percentile(times, 95)*1000:.1f}ms")

# Step 5: Convert ONNX to target runtime
# TFLite:    onnx2tf -i model.onnx -o model_tflite/
# CoreML:    current coremltools releases do not convert ONNX; convert the
#            traced PyTorch model directly with coremltools.convert instead
# TensorRT:  trtexec --onnx=model.onnx --saveEngine=model.engine
📝
Production reality: Not all ONNX ops are supported by all runtimes. TFLite does not support dynamic shapes well. CoreML requires specific op versions. Always test the converted model against your test set after conversion — numerical differences between runtimes can cause accuracy drops beyond what quantization alone would predict.
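
One way to act on that advice is to run the same inputs through the original model and the converted one, then compare both the raw outputs and the top-1 predictions. A minimal NumPy sketch, assuming you have already collected each runtime's logits as arrays of shape (num_samples, num_classes). `compare_outputs` is a hypothetical helper, and its default tolerance is illustrative, not a standard value.

```python
import numpy as np

def compare_outputs(reference, converted, atol=1e-3):
    """Summarize numerical and prediction-level agreement between two runtimes."""
    reference = np.asarray(reference, dtype=np.float64)
    converted = np.asarray(converted, dtype=np.float64)
    if reference.shape != converted.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {converted.shape}")
    abs_diff = np.abs(reference - converted)
    top1_agree = np.mean(reference.argmax(axis=1) == converted.argmax(axis=1))
    return {
        "max_abs_diff": float(abs_diff.max()),
        "mean_abs_diff": float(abs_diff.mean()),
        "within_atol": bool(np.all(abs_diff <= atol)),
        "top1_agreement": float(top1_agree),
    }

# Synthetic logits standing in for PyTorch vs. converted-runtime outputs
rng = np.random.default_rng(0)
torch_logits = rng.normal(size=(256, 1000))
converted_logits = torch_logits + rng.normal(scale=1e-4, size=torch_logits.shape)
print(compare_outputs(torch_logits, converted_logits))
```

For INT8 models, quantization itself changes the logits, so rely on top-1 agreement and full test-set accuracy rather than tight absolute tolerances.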

Size/Accuracy Trade-Off Benchmarks

The benchmark below compares optimization pipelines on ImageNet classification. Use it to choose the right approach for your size, accuracy, and latency constraints:

Model + Optimization | Size | Top-1 Acc | Jetson Nano Latency | Coral Latency
ResNet-50 FP32 (baseline) | 97MB | 76.1% | 45ms | N/A
ResNet-50 PTQ INT8 | 24MB | 74.8% | 18ms | 52ms
ResNet-50 QAT INT8 | 24MB | 75.7% | 18ms | 52ms
MobileNet-v3 Large FP32 | 21MB | 75.2% | 12ms | N/A
MobileNet-v3 Large INT8 | 5.2MB | 73.8% | 5ms | 6.2ms
MobileNet-v3 Distilled + QAT INT8 | 5.2MB | 74.9% | 5ms | 6.2ms
EfficientNet-Lite0 INT8 | 4.4MB | 73.6% | 8ms | 7.8ms
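
To turn the table into a decision, filter by your size and latency budget and take the most accurate option that fits. A minimal sketch with the rows above hard-coded (Coral latency is `None` where the model cannot run on the Edge TPU); `pick_model` is a hypothetical helper, not part of any library:

```python
from typing import Optional

# (name, size_mb, top1_pct, jetson_nano_ms, coral_ms) from the benchmark table
BENCHMARKS = [
    ("ResNet-50 FP32 (baseline)", 97.0, 76.1, 45.0, None),
    ("ResNet-50 PTQ INT8", 24.0, 74.8, 18.0, 52.0),
    ("ResNet-50 QAT INT8", 24.0, 75.7, 18.0, 52.0),
    ("MobileNet-v3 Large FP32", 21.0, 75.2, 12.0, None),
    ("MobileNet-v3 Large INT8", 5.2, 73.8, 5.0, 6.2),
    ("MobileNet-v3 Distilled + QAT INT8", 5.2, 74.9, 5.0, 6.2),
    ("EfficientNet-Lite0 INT8", 4.4, 73.6, 8.0, 7.8),
]

def pick_model(max_size_mb: float, max_latency_ms: float, device: str) -> Optional[str]:
    """Return the most accurate entry that fits the budget on 'jetson' or 'coral'."""
    col = {"jetson": 3, "coral": 4}[device]
    fits = [
        row for row in BENCHMARKS
        if row[1] <= max_size_mb and row[col] is not None and row[col] <= max_latency_ms
    ]
    return max(fits, key=lambda row: row[2])[0] if fits else None

print(pick_model(max_size_mb=6, max_latency_ms=10, device="coral"))
print(pick_model(max_size_mb=30, max_latency_ms=20, device="jetson"))
```

With a 6 MB / 10 ms Coral budget this returns the distilled MobileNet-v3; with a 30 MB / 20 ms Jetson Nano budget it returns ResNet-50 QAT INT8.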

Key Takeaways

  • PTQ is the fastest optimization path: 10 minutes of work for 4x compression with 0.5-2% accuracy loss. Start here for every edge project.
  • QAT recovers most of the PTQ accuracy loss by simulating quantization during training. Use when you need every 0.5% of accuracy.
  • Structured pruning removes whole filters, giving real speedups (not just compression) once the zeroed channels are physically removed from the network. Combine with quantization for 8-16x total compression.
  • Knowledge distillation offers the largest compression: in this lesson's example, a MobileNet-v3 Small student distilled from a ResNet-50 teacher scores 5.4 points higher than the same network trained from scratch.
  • The optimal pipeline is: cloud-train large model, distill to edge architecture, apply QAT, convert to INT8 TFLite/ONNX. This gives 10-40x compression with 2-3% accuracy loss.

What Is Next

In the next lesson, we will cover edge inference runtimes — TFLite, CoreML, ONNX Runtime, TensorRT, and OpenVINO. You will learn which runtime to use for each hardware platform and see deployment code for each.