Intermediate
Training Loops
Five challenges covering production-grade training infrastructure. These are the patterns you will write on day one at any DL team, and interviewers expect you to know them from memory.
Challenge 1: Complete Training Loop
Write a complete training and evaluation loop with all the components a production system needs.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
def train_model(model, train_loader, val_loader, epochs, lr, device):
    """
    Challenge: Write a complete training loop with:
    - Optimizer setup (AdamW)
    - Training phase with model.train()
    - Validation phase with model.eval() and torch.no_grad()
    - Loss and accuracy tracking
    - Best model saving based on validation loss
    - Return training history (dict of per-epoch metric lists)
    """
    # YOUR SOLUTION HERE
    pass
# ---- SOLUTION ----
def train_model(model, train_loader, val_loader, epochs, lr, device):
    """Train `model` with AdamW + cross-entropy, validating every epoch.

    Tracks per-epoch train loss, validation loss, and validation accuracy,
    snapshots the weights whenever validation loss improves, and restores
    the best snapshot before returning.

    Args:
        model: classifier producing logits of shape (batch, num_classes).
        train_loader / val_loader: DataLoaders yielding (inputs, labels).
        epochs: number of passes over the training set.
        lr: AdamW learning rate (weight_decay fixed at 0.01).
        device: torch.device to run on.

    Returns:
        dict with keys 'train_loss', 'val_loss', 'val_acc' — one list entry
        per epoch.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
    best_val_loss = float('inf')
    best_state = None

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        batch_losses = []
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()                    # clear stale grads
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        avg_train_loss = sum(batch_losses) / len(batch_losses)

        # ---- validation pass (eval mode, no autograd) ----
        model.eval()
        val_loss_sum = 0.0
        n_correct = 0
        n_seen = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                val_loss_sum += criterion(outputs, targets).item()
                n_correct += (outputs.argmax(dim=1) == targets).sum().item()
                n_seen += targets.size(0)
        avg_val_loss = val_loss_sum / len(val_loader)
        val_acc = n_correct / n_seen

        # ---- record metrics ----
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_acc)

        # ---- snapshot best weights (lowest validation loss so far) ----
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_state = {name: tensor.clone()
                          for name, tensor in model.state_dict().items()}

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {avg_val_loss:.4f} | "
              f"Val Acc: {val_acc:.4f}")

    # Restore the best snapshot before handing the model back.
    if best_state is not None:
        model.load_state_dict(best_state)
    return history
# Smoke test on synthetic data: 5-class classification of random vectors.
device = torch.device('cpu')
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 5))
X_train, y_train = torch.randn(200, 20), torch.randint(0, 5, (200,))
X_val, y_val = torch.randn(50, 20), torch.randint(0, 5, (50,))
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=32)
history = train_model(model, train_loader, val_loader, epochs=3, lr=1e-3, device=device)
Challenge 2: Learning Rate Scheduling
Implement warmup + cosine decay scheduling — the standard for training transformers.
import torch
import math
class WarmupCosineScheduler:
    """
    Challenge: Implement warmup + cosine annealing LR scheduler.
    - Linear warmup from 0 to max_lr over warmup_steps
    - Cosine decay from max_lr to min_lr over remaining steps
    - Must work with optimizer.param_groups
    """
    # YOUR SOLUTION HERE
    pass
# ---- SOLUTION ----
class WarmupCosineScheduler:
    """Linear warmup to `max_lr`, then cosine decay to `min_lr`.

    Call `step()` once per optimizer step; it writes the current LR into
    every param_group and returns it.

    Fix over the naive formula: decay progress is clamped to 1.0, so once
    `current_step` exceeds `total_steps` the LR holds at `min_lr` instead
    of climbing back up the cosine curve.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, max_lr, min_lr=0.0):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.current_step = 0

    def get_lr(self):
        """Return the LR for `current_step` (does not mutate anything)."""
        if self.current_step < self.warmup_steps:
            # Linear warmup: 0 -> max_lr over warmup_steps.
            # max(1, ...) guards warmup_steps == 0.
            return self.max_lr * (self.current_step / max(1, self.warmup_steps))
        # Cosine decay: max_lr -> min_lr over the remaining steps.
        progress = (self.current_step - self.warmup_steps) / max(
            1, self.total_steps - self.warmup_steps
        )
        # Clamp so training past total_steps stays at min_lr.
        progress = min(progress, 1.0)
        return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
            1 + math.cos(math.pi * progress)
        )

    def step(self):
        """Apply the current LR to all param groups, then advance the step.

        Returns the LR that was applied.
        """
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.current_step += 1
        return lr
# Sanity check: LR warms up linearly, then decays on a cosine.
import torch.nn as nn
model = nn.Linear(10, 5)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0)
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=100, total_steps=1000, max_lr=1e-3)
probe_steps = {0, 50, 100, 500, 999}
for step in range(1000):
    lr = scheduler.step()
    if step in probe_steps:
        print(f"Step {step:4d}: lr = {lr:.6f}")
Challenge 3: Gradient Clipping
Implement gradient clipping by norm and by value — essential for training RNNs and large transformers.
import torch
import torch.nn as nn
def clip_grad_norm_manual(parameters, max_norm):
    """
    Challenge: Implement torch.nn.utils.clip_grad_norm_ from scratch.
    - Compute the total gradient norm across all parameters
    - If total norm > max_norm, scale all gradients down proportionally
    - Return the total norm (before clipping)
    """
    # YOUR SOLUTION HERE
    pass
# ---- SOLUTION ----
def clip_grad_norm_manual(parameters, max_norm):
    """Clip gradients in place so their combined L2 norm is at most max_norm.

    Re-implementation of torch.nn.utils.clip_grad_norm_.

    Args:
        parameters: iterable of tensors (typically model.parameters());
            entries without a .grad are skipped.
        max_norm: maximum allowed total gradient norm.

    Returns:
        float: the total L2 norm BEFORE clipping (0.0 if no grads exist).
        The original returned a Tensor for the empty case and a float
        otherwise; this version always returns a float.
    """
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    with torch.no_grad():
        # Total L2 norm across all parameter gradients.
        total_norm = sum(g.norm(2).item() ** 2 for g in grads) ** 0.5
        # Epsilon keeps the division finite when total_norm ~ 0; matches
        # the epsilon used by torch.nn.utils.clip_grad_norm_.
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1.0:
            for g in grads:
                g.mul_(clip_coef)  # in-place scale; no deprecated .data access
    return total_norm
# Demo: clip the gradients of a small linear layer to norm <= 1.0.
model = nn.Linear(10, 5)
inputs = torch.randn(4, 10)
model(inputs).sum().backward()
# Gradient norm before and after clipping.
norm_before = sum(p.grad.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
print(f"Grad norm before clipping: {norm_before:.4f}")
norm_returned = clip_grad_norm_manual(model.parameters(), max_norm=1.0)
print(f"Grad norm returned: {norm_returned:.4f}")
norm_after = sum(p.grad.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
print(f"Grad norm after clipping: {norm_after:.4f}")  # should be <= 1.0
Challenge 4: Mixed Precision Training
Implement training with automatic mixed precision (AMP) using torch.amp (the device-agnostic successor to torch.cuda.amp) — the standard for training large models efficiently.
import torch
import torch.nn as nn
def train_with_amp(model, train_loader, epochs, lr, device):
    """
    Challenge: Implement mixed precision training with:
    - torch.amp.autocast for the forward pass
    - torch.amp.GradScaler for gradient scaling
    - Proper scaler.scale(), scaler.step(), scaler.update() sequence
    - Handle inf/nan gradients (scaler does this automatically)
    """
    # YOUR SOLUTION HERE
    pass
# ---- SOLUTION ----
def train_with_amp(model, train_loader, epochs, lr, device):
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
# AMP components
scaler = torch.amp.GradScaler(device.type)
for epoch in range(epochs):
model.train()
total_loss = 0.0
for batch_x, batch_y in train_loader:
batch_x = batch_x.to(device)
batch_y = batch_y.to(device)
optimizer.zero_grad()
# Forward pass in mixed precision
with torch.amp.autocast(device_type=device.type):
logits = model(batch_x)
loss = criterion(logits, batch_y)
# Backward pass with gradient scaling
scaler.scale(loss).backward() # Scale loss to prevent underflow
scaler.step(optimizer) # Unscale gradients and step
scaler.update() # Update scale factor
total_loss += loss.item()
avg_loss = total_loss / len(train_loader)
print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | "
f"Scale: {scaler.get_scale():.0f}")
return model
# AMP notes:
# - Roughly 2x speedup on NVIDIA GPUs with Tensor Cores; on CPU autocast
#   still runs but gives no speedup.
# - Key idea: forward runs in reduced precision (float16/bfloat16) while
#   master weights stay float32; GradScaler prevents float16 gradient
#   underflow.
# CPU smoke test — real gains need a GPU.
from torch.utils.data import DataLoader, TensorDataset
model = nn.Sequential(nn.Linear(20, 128), nn.ReLU(), nn.Linear(128, 10))
X = torch.randn(100, 20)
y = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(X, y), batch_size=32)
train_with_amp(model, loader, epochs=2, lr=1e-3, device=torch.device('cpu'))
Challenge 5: Model Checkpointing
Implement a complete checkpointing system that saves and resumes training state.
import torch
import torch.nn as nn
import os
class CheckpointManager:
    """
    Challenge: Implement a checkpoint manager that:
    - Saves model, optimizer, scheduler, epoch, and best metric
    - Supports resume from checkpoint
    - Keeps only top-K checkpoints by metric
    - Handles interrupted training gracefully
    """
    # YOUR SOLUTION HERE
    pass
# ---- SOLUTION ----
class CheckpointManager:
    """Save/restore training state, keeping only the top-K checkpoints.

    Checkpoints are ranked by `metric`, lower = better (e.g. validation
    loss). `load_best` restores the best-scoring checkpoint; `load_latest`
    restores the highest-epoch one for resuming interrupted training.
    Both return the stored epoch number, or 0 if nothing is found.

    Fix over the original: `load_latest` parsed the epoch number out of the
    FULL path, so any 'epoch' substring inside `save_dir` (e.g.
    '/runs/epoch_sweep/') corrupted the parse — the epoch is now read from
    the basename only.
    """

    def __init__(self, save_dir, max_checkpoints=3):
        self.save_dir = save_dir
        self.max_checkpoints = max_checkpoints
        self.checkpoints = []  # (metric, path), kept sorted best-first
        os.makedirs(save_dir, exist_ok=True)

    def save(self, model, optimizer, scheduler, epoch, metric, extra=None):
        """Write a checkpoint, then prune to `max_checkpoints` by metric.

        Args:
            model / optimizer: their state_dicts are stored.
            scheduler: optional; its state_dict is stored when not None.
            epoch: epoch number, embedded in the filename.
            metric: ranking value (lower = better).
            extra: optional dict merged into the checkpoint payload.

        Returns:
            Path of the saved checkpoint file.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metric': metric,
        }
        if scheduler is not None:
            checkpoint['scheduler_state_dict'] = scheduler.state_dict()
        if extra is not None:
            checkpoint.update(extra)
        path = os.path.join(self.save_dir, f'checkpoint_epoch{epoch}.pt')
        torch.save(checkpoint, path)
        self.checkpoints.append((metric, path))
        # Keep only the K lowest-metric (best) checkpoints; delete the rest.
        self.checkpoints.sort(key=lambda x: x[0])
        while len(self.checkpoints) > self.max_checkpoints:
            _, remove_path = self.checkpoints.pop()
            if os.path.exists(remove_path):
                os.remove(remove_path)
        print(f"Saved checkpoint: epoch={epoch}, metric={metric:.4f}")
        return path

    def load_best(self, model, optimizer=None, scheduler=None):
        """Load the lowest-metric checkpoint. Returns its epoch (0 if none)."""
        if not self.checkpoints:
            # Cold start (e.g. after a crash): rebuild the index from disk.
            self._scan_checkpoints()
        if not self.checkpoints:
            print("No checkpoints found.")
            return 0
        _, best_path = self.checkpoints[0]
        return self._load(best_path, model, optimizer, scheduler)

    def load_latest(self, model, optimizer=None, scheduler=None):
        """Load the highest-epoch checkpoint (for resuming training).

        Returns the stored epoch (0 if no checkpoints exist).
        """
        if not self.checkpoints:
            self._scan_checkpoints()
        if not self.checkpoints:
            print("No checkpoints found.")
            return 0
        latest_path = max((p for _, p in self.checkpoints),
                          key=self._epoch_from_path)
        return self._load(latest_path, model, optimizer, scheduler)

    @staticmethod
    def _epoch_from_path(path):
        """Parse N out of '.../checkpoint_epochN.pt'.

        Uses the basename so an 'epoch' substring inside save_dir cannot
        break the parse (the bug in the original implementation).
        """
        name = os.path.basename(path)
        return int(name.split('epoch')[1].split('.')[0])

    def _load(self, path, model, optimizer=None, scheduler=None):
        """Restore state dicts from `path`. Returns the stored epoch."""
        # weights_only=False is needed to unpickle optimizer/scheduler state,
        # but it executes arbitrary pickle — only load trusted checkpoints.
        checkpoint = torch.load(path, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print(f"Loaded checkpoint: epoch={checkpoint['epoch']}, "
              f"metric={checkpoint['metric']:.4f}")
        return checkpoint['epoch']

    def _scan_checkpoints(self):
        """Rebuild the in-memory index from files on disk (idempotent).

        Replaces the list wholesale rather than appending, so repeated
        scans cannot create duplicate entries.
        """
        found = []
        for f in os.listdir(self.save_dir):
            if f.startswith('checkpoint_') and f.endswith('.pt'):
                path = os.path.join(self.save_dir, f)
                ckpt = torch.load(path, weights_only=False)
                found.append((ckpt['metric'], path))
        self.checkpoints = sorted(found, key=lambda x: x[0])
# Demo: save a run of checkpoints with a shrinking "val loss", then reload.
model = nn.Linear(10, 5)
optimizer = torch.optim.Adam(model.parameters())
manager = CheckpointManager('/tmp/checkpoints', max_checkpoints=2)
for epoch in range(5):
    simulated_loss = 1.0 - epoch * 0.15  # decreasing loss
    manager.save(model, optimizer, None, epoch, simulated_loss)
# Restore the best checkpoint.
start_epoch = manager.load_best(model)
print(f"Resumed from epoch: {start_epoch}")
Interview tip: Interviewers check five critical things in your training loop: (1) model.train() and model.eval() mode switching, (2) optimizer.zero_grad() before backward, (3) torch.no_grad() during validation, (4) gradient clipping before the optimizer step, and (5) moving data to the correct device. Missing any of these is an immediate red flag.
Lilly Tech Systems