Intermediate

Training Loops

Five challenges covering production-grade training infrastructure. These are the patterns you will write on day one at any DL team, and interviewers expect you to know them from memory.

Challenge 1: Complete Training Loop

Write a complete training and evaluation loop with all the components a production system needs.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_model(model, train_loader, val_loader, epochs, lr, device):
    """
    Challenge: Write a complete training loop with:
    - Optimizer setup (AdamW)
    - Training phase with model.train()
    - Validation phase with model.eval() and torch.no_grad()
    - Loss and accuracy tracking
    - Best model saving based on validation loss
    - Return training history
    """
    # YOUR SOLUTION HERE
    pass  # stub — intentionally shadowed by the reference solution below

# ---- SOLUTION ----
def train_model(model, train_loader, val_loader, epochs, lr, device):
    """Train `model` with AdamW, validate each epoch, and keep the best weights.

    Args:
        model: nn.Module producing class logits of shape (batch, num_classes).
        train_loader: DataLoader yielding (inputs, integer class labels).
        val_loader: DataLoader used for validation after each epoch.
        epochs: number of full passes over train_loader.
        lr: AdamW learning rate.
        device: torch.device to run on (data and model are moved there).

    Returns:
        dict with keys 'train_loss', 'val_loss', 'val_acc', each a list with
        one entry per epoch. On return the model holds the weights from the
        epoch with the lowest validation loss.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float('inf')
    best_state = None
    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(epochs):
        # ---- Training Phase ----
        model.train()
        total_train_loss = 0.0
        num_train_batches = 0

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()              # 1. Zero gradients
            logits = model(batch_x)            # 2. Forward pass
            loss = criterion(logits, batch_y)  # 3. Compute loss
            loss.backward()                    # 4. Backward pass
            optimizer.step()                   # 5. Update weights

            total_train_loss += loss.item()
            num_train_batches += 1

        # max(1, ...) guards against ZeroDivisionError on an empty loader.
        avg_train_loss = total_train_loss / max(1, num_train_batches)

        # ---- Validation Phase ----
        model.eval()
        total_val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():                 # No gradient computation
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)

                logits = model(batch_x)
                loss = criterion(logits, batch_y)
                total_val_loss += loss.item()

                preds = logits.argmax(dim=1)
                correct += (preds == batch_y).sum().item()
                total += batch_y.size(0)

        # Same empty-loader guard for the validation averages.
        avg_val_loss = total_val_loss / max(1, len(val_loader))
        val_acc = correct / max(1, total)

        # ---- Track History ----
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_acc)

        # ---- Save Best Model ----
        # Clone so later optimizer steps can't mutate the saved snapshot.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_state = {k: v.clone() for k, v in model.state_dict().items()}

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {avg_val_loss:.4f} | "
              f"Val Acc: {val_acc:.4f}")

    # Restore best model
    if best_state is not None:
        model.load_state_dict(best_state)

    return history

# Test with dummy data: 5-way classification from 20 random features.
device = torch.device('cpu')
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 5))
train_ds = TensorDataset(torch.randn(200, 20), torch.randint(0, 5, (200,)))
val_ds = TensorDataset(torch.randn(50, 20), torch.randint(0, 5, (50,)))
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)
history = train_model(model, train_loader, val_loader, epochs=3, lr=1e-3, device=device)

Challenge 2: Learning Rate Scheduling

Implement warmup + cosine decay scheduling — the standard for training transformers.

import torch
import math

class WarmupCosineScheduler:
    """
    Challenge: Implement warmup + cosine annealing LR scheduler.
    - Linear warmup from 0 to max_lr over warmup_steps
    - Cosine decay from max_lr to min_lr over remaining steps
    - Must work with optimizer.param_groups
    """
    # YOUR SOLUTION HERE
    pass  # stub — intentionally shadowed by the reference solution below

# ---- SOLUTION ----
class WarmupCosineScheduler:
    """Linear warmup to `max_lr`, then cosine decay to `min_lr`.

    Call `step()` once per optimizer step; it writes the new learning rate
    into every param_group and returns it.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, max_lr, min_lr=0.0):
        """
        Args:
            optimizer: torch optimizer whose param_groups get the LR.
            warmup_steps: steps of linear warmup from 0 to max_lr.
            total_steps: total steps of the schedule (warmup + decay).
            max_lr: peak learning rate reached at the end of warmup.
            min_lr: floor learning rate at and after total_steps.
        """
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.current_step = 0

    def get_lr(self):
        """Return the LR for the current step without mutating any state."""
        if self.current_step < self.warmup_steps:
            # Linear warmup; max(1, ...) tolerates warmup_steps == 0.
            return self.max_lr * (self.current_step / max(1, self.warmup_steps))
        else:
            # Cosine decay from max_lr down to min_lr.
            progress = (self.current_step - self.warmup_steps) / max(
                1, self.total_steps - self.warmup_steps
            )
            # BUGFIX: clamp progress to 1.0 — without this, stepping past
            # total_steps pushes cos past pi and the LR climbs back up.
            progress = min(progress, 1.0)
            return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
                1 + math.cos(math.pi * progress)
            )

    def step(self):
        """Apply the current LR to all param_groups, advance, and return it."""
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.current_step += 1
        return lr

# Demo: drive the scheduler for 1000 steps and sample the LR curve.
import torch.nn as nn
model = nn.Linear(10, 5)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0)
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=100, total_steps=1000, max_lr=1e-3)

# Report the LR at a few representative points of the schedule.
report_steps = {0, 50, 100, 500, 999}
for step in range(1000):
    lr = scheduler.step()
    if step in report_steps:
        print(f"Step {step:4d}: lr = {lr:.6f}")

Challenge 3: Gradient Clipping

Implement gradient clipping by norm and by value — essential for training RNNs and large transformers.

import torch
import torch.nn as nn

def clip_grad_norm_manual(parameters, max_norm):
    """
    Challenge: Implement torch.nn.utils.clip_grad_norm_ from scratch.
    - Compute the total gradient norm across all parameters
    - If total norm > max_norm, scale all gradients down proportionally
    - Return the total norm (before clipping)
    """
    # YOUR SOLUTION HERE
    pass  # stub — intentionally shadowed by the reference solution below

# ---- SOLUTION ----
def clip_grad_norm_manual(parameters, max_norm):
    """Clip gradients in place so their global L2 norm is at most `max_norm`.

    Mirrors torch.nn.utils.clip_grad_norm_: the norm is computed jointly
    over all parameters, and every gradient is scaled by the same factor.

    Args:
        parameters: iterable of tensors; entries with `.grad is None` are skipped.
        max_norm: maximum allowed global gradient norm.

    Returns:
        float: the total norm measured BEFORE clipping (consistently a float,
        including the empty-parameter case — the original mixed tensor/float).
    """
    params = [p for p in parameters if p.grad is not None]
    if not params:
        return 0.0

    # 1. Global L2 norm: sqrt of the sum of squared per-parameter norms.
    total_norm = sum(p.grad.norm(2).item() ** 2 for p in params) ** 0.5

    # 2. Scale factor; epsilon avoids division by zero for all-zero grads.
    clip_coef = max_norm / (total_norm + 1e-6)

    # 3. Only shrink — never amplify gradients that are already in budget.
    if clip_coef < 1.0:
        with torch.no_grad():  # in-place edit must not be traced by autograd
            for p in params:
                p.grad.mul_(clip_coef)

    return total_norm

# Demo: compare the global gradient norm before and after clipping.
model = nn.Linear(10, 5)
x = torch.randn(4, 10)
loss = model(x).sum()
loss.backward()

def _global_grad_norm(module):
    """Global L2 norm over all parameter gradients of `module`."""
    return sum(p.grad.norm(2).item() ** 2 for p in module.parameters()) ** 0.5

norm_before = _global_grad_norm(model)
print(f"Grad norm before clipping: {norm_before:.4f}")

norm_returned = clip_grad_norm_manual(model.parameters(), max_norm=1.0)
print(f"Grad norm returned:        {norm_returned:.4f}")

norm_after = _global_grad_norm(model)
print(f"Grad norm after clipping:  {norm_after:.4f}")  # should be <= 1.0

Challenge 4: Mixed Precision Training

Implement training with automatic mixed precision (AMP) using torch.amp — the standard for training large models efficiently. (The older torch.cuda.amp namespace is a deprecated alias for the same API.)

import torch
import torch.nn as nn

def train_with_amp(model, train_loader, epochs, lr, device):
    """
    Challenge: Implement mixed precision training with:
    - torch.amp.autocast for the forward pass (torch.cuda.amp is the legacy alias)
    - torch.amp.GradScaler for gradient scaling
    - Proper scaler.scale(), scaler.step(), scaler.update() sequence
    - Handle inf/nan gradients (scaler does this automatically)
    """
    # YOUR SOLUTION HERE
    pass  # stub — intentionally shadowed by the reference solution below

# ---- SOLUTION ----
def train_with_amp(model, train_loader, epochs, lr, device):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # AMP components
    scaler = torch.amp.GradScaler(device.type)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()

            # Forward pass in mixed precision
            with torch.amp.autocast(device_type=device.type):
                logits = model(batch_x)
                loss = criterion(logits, batch_y)

            # Backward pass with gradient scaling
            scaler.scale(loss).backward()    # Scale loss to prevent underflow
            scaler.step(optimizer)           # Unscale gradients and step
            scaler.update()                  # Update scale factor

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | "
              f"Scale: {scaler.get_scale():.0f}")

    return model

# Note: AMP provides ~2x speedup on NVIDIA GPUs with Tensor Cores
# On CPU, autocast still works but provides no speedup
# The key insight: autocast runs eligible forward ops in reduced precision
# (float16/bfloat16) while parameters and optimizer updates stay in float32
# GradScaler prevents the reduced-precision gradients from underflowing to zero

# Test (CPU demo - real speedup requires GPU)
from torch.utils.data import DataLoader, TensorDataset

model = nn.Sequential(nn.Linear(20, 128), nn.ReLU(), nn.Linear(128, 10))
features = torch.randn(100, 20)
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(features, labels), batch_size=32)
train_with_amp(model, loader, epochs=2, lr=1e-3, device=torch.device('cpu'))

Challenge 5: Model Checkpointing

Implement a complete checkpointing system that saves and resumes training state.

import torch
import torch.nn as nn
import os

class CheckpointManager:
    """
    Challenge: Implement a checkpoint manager that:
    - Saves model, optimizer, scheduler, epoch, and best metric
    - Supports resume from checkpoint
    - Keeps only top-K checkpoints by metric
    - Handles interrupted training gracefully
    """
    # YOUR SOLUTION HERE
    pass  # stub — intentionally shadowed by the reference solution below

# ---- SOLUTION ----
class CheckpointManager:
    """Persist and restore training state, keeping only the best K checkpoints.

    `metric` is treated as lower-is-better (e.g. validation loss).
    Checkpoints are written into `save_dir` as 'checkpoint_epoch{N}.pt'.
    """

    def __init__(self, save_dir, max_checkpoints=3):
        """
        Args:
            save_dir: directory for checkpoint files (created if missing).
            max_checkpoints: how many best-by-metric checkpoints to keep.
        """
        self.save_dir = save_dir
        self.max_checkpoints = max_checkpoints
        self.checkpoints = []  # list of (metric, path), kept sorted best-first
        os.makedirs(save_dir, exist_ok=True)

    def save(self, model, optimizer, scheduler, epoch, metric, extra=None):
        """Write one checkpoint and prune the set down to max_checkpoints.

        Returns the path of the file just written (it may immediately be
        pruned if it is worse than the current top-K).
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metric': metric,
        }
        if scheduler is not None:
            checkpoint['scheduler_state_dict'] = scheduler.state_dict()
        if extra is not None:
            checkpoint.update(extra)

        path = os.path.join(self.save_dir, f'checkpoint_epoch{epoch}.pt')
        torch.save(checkpoint, path)
        self.checkpoints.append((metric, path))

        # Keep only top-K checkpoints (lowest metric = best).
        self.checkpoints.sort(key=lambda x: x[0])
        while len(self.checkpoints) > self.max_checkpoints:
            _, remove_path = self.checkpoints.pop()  # drop the worst
            if os.path.exists(remove_path):
                os.remove(remove_path)

        print(f"Saved checkpoint: epoch={epoch}, metric={metric:.4f}")
        return path

    def load_best(self, model, optimizer=None, scheduler=None):
        """Load the best (lowest-metric) checkpoint; return its epoch (0 if none)."""
        if not self.checkpoints:
            # Pick up checkpoints left on disk by a previous (interrupted) run.
            self._scan_checkpoints()
        if not self.checkpoints:
            print("No checkpoints found.")
            return 0

        best_metric, best_path = self.checkpoints[0]
        return self._load(best_path, model, optimizer, scheduler)

    def load_latest(self, model, optimizer=None, scheduler=None):
        """Load the highest-epoch checkpoint (for resuming); return its epoch."""
        if not self.checkpoints:
            self._scan_checkpoints()
        if not self.checkpoints:
            print("No checkpoints found.")
            return 0

        # BUGFIX: parse the epoch from the basename only. The original split
        # the FULL path on 'epoch', which broke whenever save_dir itself
        # contained that substring (e.g. '/runs/epoch_sweep/...').
        latest_path = max(
            (p for _, p in self.checkpoints),
            key=self._epoch_from_path,
        )
        return self._load(latest_path, model, optimizer, scheduler)

    @staticmethod
    def _epoch_from_path(path):
        """Extract N from a '.../checkpoint_epoch{N}.pt' path."""
        name = os.path.basename(path)
        return int(name[len('checkpoint_epoch'):-len('.pt')])

    def _load(self, path, model, optimizer=None, scheduler=None):
        """Restore state dicts from `path` into the given objects; return epoch."""
        # NOTE(security): weights_only=False deserializes via pickle —
        # only ever load checkpoint files you wrote yourself / fully trust.
        checkpoint = torch.load(path, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print(f"Loaded checkpoint: epoch={checkpoint['epoch']}, "
              f"metric={checkpoint['metric']:.4f}")
        return checkpoint['epoch']

    def _scan_checkpoints(self):
        """Rebuild self.checkpoints from files already present in save_dir."""
        for f in os.listdir(self.save_dir):
            if f.startswith('checkpoint_') and f.endswith('.pt'):
                path = os.path.join(self.save_dir, f)
                ckpt = torch.load(path, weights_only=False)  # trusted files only
                self.checkpoints.append((ckpt['metric'], path))
        self.checkpoints.sort(key=lambda x: x[0])

# Demo: checkpoint five epochs of (simulated) improving validation loss,
# then restore the best checkpoint.
model = nn.Linear(10, 5)
optimizer = torch.optim.Adam(model.parameters())
manager = CheckpointManager('/tmp/checkpoints', max_checkpoints=2)

for epoch in range(5):
    manager.save(model, optimizer, None, epoch, 1.0 - epoch * 0.15)

start_epoch = manager.load_best(model)
print(f"Resumed from epoch: {start_epoch}")
💡 Interview tip: The five critical things interviewers check in your training loop: (1) model.train() and model.eval() mode switching, (2) optimizer.zero_grad() before backward, (3) torch.no_grad() during validation, (4) gradient clipping before the optimizer step, (5) moving data to the correct device. Missing any of these is an immediate red flag.