Model Optimization for Edge
A ResNet-50 is 97MB. A Jetson Nano has 4GB of shared RAM. A Coral TPU only runs INT8 models. Edge deployment means making your model smaller, faster, and more efficient — without destroying accuracy. This lesson covers the four core optimization techniques with production code you can run today.
Optimization Techniques Overview
There are four main techniques for making models edge-ready. They can be combined for maximum compression:
| Technique | Size Reduction | Speed Improvement | Accuracy Impact | Effort |
|---|---|---|---|---|
| Post-Training Quantization (PTQ) | 4x (FP32 → INT8) | 2-4x | 0.5-2% drop | Low (10 min) |
| Quantization-Aware Training (QAT) | 4x (FP32 → INT8) | 2-4x | 0.1-0.5% drop | Medium (retrain) |
| Pruning (Structured) | 2-5x | 2-3x | 0.5-3% drop | Medium |
| Knowledge Distillation | 10-50x | 5-20x | 1-5% drop | High (train student) |
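The 4x size reduction in the quantization rows comes from storing each weight in one byte instead of four. A minimal NumPy sketch of the affine INT8 mapping behind it (all names here are illustrative, not a library API):

```python
import numpy as np

def quantize_affine(x, num_bits=8):
    """Map floats to uint8 with a scale and zero point -- the same
    affine scheme TFLite uses for asymmetric quantization."""
    qmin, qmax = 0, 2**num_bits - 1
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = int(round(qmin - x.min() / scale))
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.uint8)
    return q, scale, zero_point

def dequantize_affine(q, scale, zero_point):
    """Recover approximate floats from the uint8 codes."""
    return scale * (q.astype(np.float32) - zero_point)

weights = np.random.randn(1000).astype(np.float32)
q, scale, zp = quantize_affine(weights)
recovered = dequantize_affine(q, scale, zp)
print(f"Storage: {weights.nbytes} bytes -> {q.nbytes} bytes")  # 4x smaller
print(f"Max rounding error: {np.abs(weights - recovered).max():.4f}")
```

The rounding error is bounded by roughly half the scale, which is why quantization costs so little accuracy on well-behaved weight distributions.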
Post-Training Quantization (PTQ)
PTQ converts a trained FP32 model to INT8 without any retraining. It is the fastest path from cloud model to edge deployment:
```python
import os

import numpy as np
import tensorflow as tf

# Step 1: Load your trained model
model = tf.keras.models.load_model("resnet50_trained.h5")
print(f"Original model size: {os.path.getsize('resnet50_trained.h5') / 1e6:.1f} MB")
# -> Original model size: 97.8 MB

# Step 2: Convert to TFLite with INT8 quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Provide a representative dataset for calibration (100-500 samples)
def representative_dataset():
    """Feed real data so the converter can determine value ranges."""
    # training_data: your preprocessed training images, shape (N, 224, 224, 3)
    for i in range(200):
        sample = training_data[i:i+1].astype(np.float32)
        yield [sample]

converter.representative_dataset = representative_dataset

# Force all ops to INT8 (required for Coral TPU, optimal for all edge)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS_INT8
]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

# Convert and save
tflite_model = converter.convert()
with open("resnet50_int8.tflite", "wb") as f:
    f.write(tflite_model)

print(f"Quantized model size: {len(tflite_model) / 1e6:.1f} MB")
# -> Quantized model size: 24.5 MB (4x smaller)
print(f"Compression ratio: {97.8 / 24.5:.1f}x")
```
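What the representative dataset is for: during conversion, the converter records the value range of each tensor over the calibration samples and derives a scale and zero point from it. A simplified per-tensor version of that calibration step (names and data are illustrative, not the converter's internals):

```python
import numpy as np

def calibrate_scale(calibration_batches):
    """Track the observed min/max over calibration data, then derive
    a uint8 scale and zero point -- roughly what PTQ does per tensor."""
    lo, hi = np.inf, -np.inf
    for batch in calibration_batches:
        lo = min(lo, float(batch.min()))
        hi = max(hi, float(batch.max()))
    scale = (hi - lo) / 255.0
    zero_point = int(round(-lo / scale))
    return scale, zero_point

# 200 fake "activation" batches standing in for real calibration samples
rng = np.random.default_rng(0)
batches = [rng.uniform(0.0, 6.0, size=(1, 28, 28, 8)) for _ in range(200)]
scale, zp = calibrate_scale(batches)
print(f"scale={scale:.4f}, zero_point={zp}")
```

This is also why 100-500 *real* samples matter: random or unrepresentative inputs produce the wrong ranges, and every activation gets quantized against a bad scale.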
Quantization-Aware Training (QAT)
QAT simulates quantization during training, so the model learns to be accurate with reduced precision. Use this when PTQ accuracy is not good enough:
```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Step 1: Start with your trained FP32 model
base_model = tf.keras.models.load_model("resnet50_trained.h5")

# Step 2: Apply QAT - inserts fake quantization nodes
quantize_model = tfmot.quantization.keras.quantize_model
qat_model = quantize_model(base_model)

# Step 3: Fine-tune with quantization simulation (10-20% of original epochs)
qat_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # Lower LR
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)
# train_data / val_data: tf.data.Dataset objects yielding batched (images, labels);
# batch size is set on the dataset, not passed to fit()
qat_model.fit(
    train_data,
    epochs=5,  # 10-20% of original training
    validation_data=val_data
)

# Step 4: Convert to TFLite INT8
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open("resnet50_qat_int8.tflite", "wb") as f:
    f.write(tflite_model)

# Benchmark comparison:
# PTQ INT8: Top-1 accuracy 74.8% (was 76.1% FP32) - 1.3% drop
# QAT INT8: Top-1 accuracy 75.7% (was 76.1% FP32) - 0.4% drop
```
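The "fake quantization nodes" inserted above are just a quantize-dequantize round trip applied in the forward pass, so the network trains against exactly the rounding error it will see at inference. A NumPy sketch of one such node (an illustrative symmetric int8 variant, not the tfmot implementation):

```python
import numpy as np

def fake_quant(x, num_bits=8):
    """Quantize-dequantize round trip: the output is still float32,
    but carries the rounding error of int8 inference."""
    max_abs = np.abs(x).max()
    scale = max_abs / (2**(num_bits - 1) - 1)   # symmetric int8 range [-127, 127]
    q = np.clip(np.round(x / scale), -127, 127)
    return (q * scale).astype(np.float32)

x = np.random.randn(4, 4).astype(np.float32)
xq = fake_quant(x)
print("still float32:", xq.dtype)
print("max error:", np.abs(x - xq).max())
```

Because the output stays float, gradients flow through almost unchanged (the straight-through estimator), and the weights drift toward values that survive rounding well.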
Structured Pruning
Pruning removes weights or entire filters that contribute least to the output. Structured pruning removes whole filters, making the model actually faster (not just smaller):
```python
import torch
import torch.nn.utils.prune as prune

# Step 1: Load trained PyTorch model
model = torch.load("resnet50_trained.pt")
model.eval()

def apply_structured_pruning(model, amount=0.3):
    """
    Zero out 30% of the filters in each conv layer based on L1 norm.
    Filters with the smallest weight magnitudes contribute least.

    Note: prune.ln_structured only zeroes the selected filters in place.
    To realize the size and latency numbers below, the zeroed channels
    must then be physically removed from the network (e.g., with a
    channel-slimming tool that rebuilds the conv layers).
    """
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.ln_structured(
                module,
                name="weight",
                amount=amount,  # Zero 30% of filters
                n=1,            # L1 norm
                dim=0           # Prune along output channel dimension
            )
            # Make pruning permanent (fold the mask into the weights)
            prune.remove(module, "weight")
    return model

# Step 2: Prune
pruned_model = apply_structured_pruning(model, amount=0.3)

# Step 3: Fine-tune to recover accuracy (critical step)
pruned_model.train()
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=1e-4, momentum=0.9)
for epoch in range(10):
    for images, labels in train_loader:
        outputs = pruned_model(images)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Step 4: Export to ONNX for edge deployment
pruned_model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    pruned_model, dummy_input, "resnet50_pruned.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
)

# Results (after physically removing the pruned channels):
# Original:   97MB, 4.1 GFLOPs, 76.1% accuracy
# 30% pruned: 68MB, 2.9 GFLOPs, 75.4% accuracy (after fine-tuning)
# 50% pruned: 49MB, 2.1 GFLOPs, 74.2% accuracy (after fine-tuning)
```
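The L1-norm criterion used by `ln_structured` above is easy to see directly: rank each filter by the sum of its absolute weights and mark the smallest fraction for removal. A NumPy illustration (the function name and data are made up for this sketch):

```python
import numpy as np

def filters_to_prune(conv_weight, amount=0.3):
    """conv_weight: (out_channels, in_channels, kH, kW).
    Return indices of the `amount` fraction of filters with the
    smallest L1 norm -- the ones ln_structured would zero out."""
    l1 = np.abs(conv_weight).sum(axis=(1, 2, 3))   # one norm per output filter
    n_prune = int(amount * len(l1))
    return np.argsort(l1)[:n_prune]

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 32, 3, 3))
w[[3, 17, 40]] *= 0.01   # make three filters nearly dead
idx = filters_to_prune(w, amount=0.3)
print(f"pruning {len(idx)} of 64 filters")   # the near-dead filters rank lowest
```

L1 norm is a crude but cheap saliency proxy; more elaborate criteria (Taylor expansion, activation statistics) exist, but magnitude pruning plus fine-tuning remains the standard baseline.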
Knowledge Distillation
Distillation trains a small "student" model to mimic a large "teacher" model. The student learns the teacher's soft probability distributions, which contain more information than hard labels:
```python
import torch
import torch.nn.functional as F

class DistillationTrainer:
    """
    Train a small student model to mimic a large teacher model.
    The student learns from both the true labels AND the teacher's
    soft predictions (which encode inter-class relationships).
    """
    def __init__(self, teacher, student, temperature=4.0, alpha=0.7):
        self.teacher = teacher.eval()   # Inference mode; gradients skipped via no_grad below
        self.student = student
        self.temperature = temperature  # Higher = softer distributions
        self.alpha = alpha              # Weight of distillation vs hard loss

    def distillation_loss(self, student_logits, teacher_logits, true_labels):
        # Soft loss: KL divergence between teacher and student soft outputs
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_loss = F.kl_div(
            soft_student, soft_teacher, reduction="batchmean"
        ) * (self.temperature ** 2)
        # Hard loss: standard cross-entropy with true labels
        hard_loss = F.cross_entropy(student_logits, true_labels)
        # Combined loss
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def train_epoch(self, dataloader, optimizer):
        self.student.train()
        total_loss = 0
        for images, labels in dataloader:
            # Teacher prediction (no gradient needed)
            with torch.no_grad():
                teacher_logits = self.teacher(images)
            # Student prediction
            student_logits = self.student(images)
            # Combined loss
            loss = self.distillation_loss(
                student_logits, teacher_logits, labels
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss / len(dataloader)

# Usage: Distill ResNet-50 (25M params) into MobileNet-v3 Small (2.5M params)
# (load_model, MobileNetV3Small, train_loader, evaluate: defined elsewhere in your project)
teacher = load_model("resnet50_trained.pt")   # 97MB, 76.1% acc
student = MobileNetV3Small(num_classes=1000)  # 10MB
trainer = DistillationTrainer(teacher, student, temperature=4.0, alpha=0.7)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

for epoch in range(100):
    loss = trainer.train_epoch(train_loader, optimizer)
    val_acc = evaluate(student, val_loader)
    print(f"Epoch {epoch}: loss={loss:.4f}, val_acc={val_acc:.2%}")

# Results:
# MobileNet-v3 trained from scratch:  67.4% accuracy, 10MB
# MobileNet-v3 distilled from ResNet: 72.8% accuracy, 10MB (+5.4%!)
# Teacher ResNet-50:                  76.1% accuracy, 97MB
```
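Why the temperature matters: dividing the logits by T before the softmax flattens the distribution, exposing the teacher's knowledge about near-miss classes instead of a near-one-hot output. A quick NumPy check (the logit values are invented for illustration):

```python
import numpy as np

def softmax_T(logits, T=1.0):
    """Temperature-scaled softmax."""
    z = logits / T
    e = np.exp(z - z.max())   # subtract max for numerical stability
    return e / e.sum()

# Teacher logits for an image of a cat: dog is a plausible confusion
logits = np.array([9.0, 5.0, 1.0, 0.5])  # cat, dog, car, plane

p1 = softmax_T(logits, T=1.0)
p4 = softmax_T(logits, T=4.0)
print("T=1:", np.round(p1, 3))   # nearly one-hot on cat
print("T=4:", np.round(p4, 3))   # dog's similarity to cat becomes visible
```

At T=1 the "cat vs dog" relationship is numerically invisible to the student; at T=4 it dominates the soft targets, which is the extra signal distillation trains on.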
Edge-Optimized Architectures
Instead of compressing large models, start with architectures designed for edge from the ground up:
| Architecture | Params | Size (INT8) | Latency (Coral) | ImageNet Top-1 | Best For |
|---|---|---|---|---|---|
| MobileNet-v3 Small | 2.5M | 2.4MB | 3.5ms | 67.4% | Ultra-low latency, microcontrollers |
| MobileNet-v3 Large | 5.4M | 5.2MB | 6.2ms | 75.2% | Balanced accuracy/speed |
| EfficientNet-Lite0 | 4.7M | 4.4MB | 7.8ms | 75.1% | Best accuracy per FLOP |
| EfficientNet-Lite4 | 13M | 12.5MB | 52ms | 80.2% | High accuracy edge (Jetson) |
| YOLO-NAS-S | 12.2M | 12MB | 15ms (Jetson) | 47.5 mAP | Real-time object detection |
ONNX Conversion Pipeline
ONNX is the universal exchange format for edge models. Convert once, deploy to any runtime:
```python
import time

import numpy as np
import onnx
import onnxruntime as ort
import torch

# Step 1: Export PyTorch model to ONNX
model = load_trained_model()
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    },
    opset_version=17,         # Use latest stable opset
    do_constant_folding=True  # Optimize constant expressions
)

# Step 2: Validate the ONNX model
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print(f"ONNX model validated: {len(onnx_model.graph.node)} ops")

# Step 3: Optimize with ONNX Runtime
# This fuses operations and optimizes the graph for inference
sess_options = ort.SessionOptions()
sess_options.optimized_model_filepath = "model_optimized.onnx"
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession("model.onnx", sess_options)

# Step 4: Benchmark
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
times = []
for _ in range(100):
    start = time.perf_counter()
    result = session.run(None, {"input": input_data})
    times.append(time.perf_counter() - start)
print(f"Avg inference: {np.mean(times)*1000:.1f}ms")
print(f"P95 inference: {np.percentile(times, 95)*1000:.1f}ms")

# Step 5: Convert ONNX to target runtime
# TFLite:   onnx2tf -i model.onnx -o model_tflite/
# CoreML:   convert from the PyTorch model directly with coremltools.convert
#           (recent coremltools versions no longer accept ONNX input)
# TensorRT: trtexec --onnx=model.onnx --saveEngine=model.engine
```
Size/Accuracy Trade-Off Benchmarks
Here is a comprehensive benchmark comparing optimization pipelines on ImageNet classification. Use this to choose the right approach for your constraints:
| Model + Optimization | Size | Top-1 Acc | Jetson Nano Latency | Coral Latency |
|---|---|---|---|---|
| ResNet-50 FP32 (baseline) | 97MB | 76.1% | 45ms | N/A |
| ResNet-50 PTQ INT8 | 24MB | 74.8% | 18ms | 52ms |
| ResNet-50 QAT INT8 | 24MB | 75.7% | 18ms | 52ms |
| MobileNet-v3 Large FP32 | 21MB | 75.2% | 12ms | N/A |
| MobileNet-v3 Large INT8 | 5.2MB | 73.8% | 5ms | 6.2ms |
| MobileNet-v3 Distilled + QAT INT8 | 5.2MB | 74.9% | 5ms | 6.2ms |
| EfficientNet-Lite0 INT8 | 4.4MB | 73.6% | 8ms | 7.8ms |
Key Takeaways
- PTQ is the fastest optimization path: 10 minutes of work for 4x compression with 0.5-2% accuracy loss. Start here for every edge project.
- QAT recovers most of the PTQ accuracy loss by simulating quantization during training. Use when you need every 0.5% of accuracy.
- Structured pruning removes whole filters, giving real speedups (not just compression). Combine with quantization for 8-16x total compression.
- Knowledge distillation is the most powerful technique: train a MobileNet student from a ResNet teacher to get +5% accuracy over training from scratch.
- The optimal pipeline is: cloud-train large model, distill to edge architecture, apply QAT, convert to INT8 TFLite/ONNX. This gives 10-40x compression with 2-3% accuracy loss.
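The 10-40x figure in the last point is just the pipeline stages composing multiplicatively. A back-of-the-envelope check using this lesson's own benchmark numbers (MobileNet-v3 Large as the student):

```python
# Compression compounds across pipeline stages
teacher_mb = 97.0    # ResNet-50 FP32 teacher
student_mb = 21.0    # MobileNet-v3 Large FP32 student (distillation target)
int8_factor = 4.0    # FP32 -> INT8 quantization

final_mb = student_mb / int8_factor
total_compression = teacher_mb / final_mb
print(f"Final model: {final_mb:.1f} MB, total compression: {total_compression:.1f}x")
```

Distilling into MobileNet-v3 Small instead (10MB FP32) pushes the same arithmetic toward the 40x end of the range.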
What Is Next
In the next lesson, we will cover edge inference runtimes — TFLite, CoreML, ONNX Runtime, TensorRT, and OpenVINO. You will learn which runtime to use for each hardware platform and see deployment code for each.
Lilly Tech Systems