Advanced

Debugging & Optimization

Five challenges that test the debugging and optimization skills that separate senior DL engineers from junior ones. In interviews, you may be given broken code and asked to find the bug — these challenges prepare you for exactly that.

Challenge 1: Find the Bugs

Each code snippet below has one or more subtle bugs. Find them all.

import torch
import torch.nn as nn

# ---- BUG 1: The model that doesn't learn ----
class BuggyModel(nn.Module):
    """Two-layer MLP with one intentional bug, kept for the exercise."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        # BUG (intentional): ReLU applied to the output layer clamps the
        # logits to be non-negative, which breaks cross-entropy training.
        return self.relu(self.fc2(hidden))
# FIX: Remove the final ReLU -- logits should be raw for cross-entropy

# ---- BUG 2: Training loop memory leak ----
def buggy_train(model, loader, optimizer, criterion):
    """Run one training epoch and return the mean loss.

    Contains one intentional bug, kept for the exercise (see fix note below).
    """
    running = 0.0
    for inputs, targets in loader:
        optimizer.zero_grad()
        predictions = model(inputs)
        batch_loss = criterion(predictions, targets)
        batch_loss.backward()
        optimizer.step()
        # BUG (intentional): adding the loss *tensor* instead of loss.item()
        # keeps every batch's computation graph alive -> memory leak.
        running = running + batch_loss
    return running / len(loader)
# FIX: Use total_loss += loss.item()  -- .item() extracts scalar, frees graph

# ---- BUG 3: Validation with wrong mode ----
def buggy_evaluate(model, loader, criterion):
    """Compute the mean loss over a loader.

    Two intentional bugs, kept for the exercise (see fix note below):
    1. model.eval() is never called -- dropout/batchnorm stay in train mode.
    2. torch.no_grad() is missing -- autograd builds graphs, wasting memory.
    """
    loss_sum = 0.0
    for batch_x, batch_y in loader:
        loss_sum += criterion(model(batch_x), batch_y).item()
    return loss_sum / len(loader)
# FIX:
# model.eval()
# with torch.no_grad():
#     ... (evaluation code)
# model.train()  # restore training mode after

# ---- BUG 4: Gradient accumulation gone wrong ----
def buggy_gradient_accumulation(model, loader, optimizer, criterion, accum_steps=4):
    """Gradient accumulation over ``accum_steps`` mini-batches per update.

    Contains one intentional bug, kept for the exercise:
    optimizer.zero_grad() is never called, so gradients from one mega-batch
    leak into the next. FIX: zero the grads right after optimizer.step().
    """
    for step, (batch_x, batch_y) in enumerate(loader, start=1):
        # Scale each mini-batch loss so the accumulated gradient matches
        # the gradient of one big batch.
        scaled_loss = criterion(model(batch_x), batch_y) / accum_steps
        scaled_loss.backward()
        if step % accum_steps == 0:
            optimizer.step()

# ---- BUG 5: The sneaky device mismatch ----
def buggy_inference(model, x):
    """Run a forward pass on GPU -- with an intentional device-mismatch bug.

    Requires a CUDA-capable device (model.cuda() raises otherwise).
    """
    model = model.cuda()
    # BUG: x is still on CPU!
    out = model(x)  # RuntimeError: expected all tensors on same device
    return out
# FIX: x = x.cuda()  OR  x = x.to(next(model.parameters()).device)

# End-of-challenge marker; the actual fixes are described in the comments above.
print("All 5 bugs identified and fixed!")

Challenge 2: Memory Optimization

Reduce GPU memory usage for training large models — the difference between fitting on one GPU and needing four.

import torch
import torch.nn as nn

class MemoryEfficientTrainer:
    """
    Challenge: Implement memory optimization techniques:
    1. Gradient checkpointing (trade compute for memory)
    2. In-place operations where safe
    3. Clearing intermediate activations
    4. Monitoring GPU memory usage
    """
    # YOUR SOLUTION HERE
    # Placeholder body -- a reference solution follows below.
    pass

# ---- SOLUTION ----
from torch.utils.checkpoint import checkpoint

class CheckpointedTransformerBlock(nn.Module):
    """Pre-norm transformer block with optional gradient checkpointing.

    When ``use_checkpoint`` is True and the module is in training mode,
    activations inside each sub-block are discarded during the forward pass
    and recomputed during backward, trading extra compute for memory.
    """

    def __init__(self, d_model, num_heads, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def _attn_block(self, x):
        # Self-attention sub-block with residual connection (pre-norm).
        y = self.norm1(x)
        y, _ = self.attn(y, y, y)
        return x + y

    def _ffn_block(self, x):
        # Feed-forward sub-block with residual connection (pre-norm).
        return x + self.ffn(self.norm2(x))

    def forward(self, x):
        # Checkpointing only pays off when gradients will be computed,
        # so it is enabled exclusively in training mode.
        if not (self.use_checkpoint and self.training):
            return self._ffn_block(self._attn_block(x))
        x = checkpoint(self._attn_block, x, use_reentrant=False)
        return checkpoint(self._ffn_block, x, use_reentrant=False)


def get_gpu_memory_stats():
    """Return current GPU memory usage in megabytes.

    Returns a dict with keys ``allocated_mb``, ``reserved_mb`` and
    ``max_allocated_mb``; all values are 0.0 when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        # FIX: the CPU fallback previously returned different key names
        # ("allocated", "reserved", "max_allocated"), so callers indexing
        # e.g. stats["allocated_mb"] crashed on CPU-only machines.
        return {"allocated_mb": 0.0, "reserved_mb": 0.0, "max_allocated_mb": 0.0}
    return {
        "allocated_mb": torch.cuda.memory_allocated() / 1024**2,
        "reserved_mb": torch.cuda.memory_reserved() / 1024**2,
        "max_allocated_mb": torch.cuda.max_memory_allocated() / 1024**2,
    }


def memory_efficient_inference(model, data_loader, device):
    """Run inference over ``data_loader`` while keeping GPU memory low.

    Each batch is moved to ``device``, predictions are transferred back to
    the CPU immediately, and GPU tensors are released between batches.

    Returns the concatenated predictions on CPU, or an empty tensor if the
    loader yields no batches.
    """
    model.eval()
    all_predictions = []

    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            preds = model(batch)
            # Move predictions to CPU immediately to free GPU memory
            all_predictions.append(preds.cpu())
            # Drop references so the allocator can reuse the GPU blocks
            del batch, preds

    # Return cached blocks to the driver so other processes can use them
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # FIX: torch.cat raises on an empty list; handle an empty loader.
    if not all_predictions:
        return torch.empty(0)
    return torch.cat(all_predictions, dim=0)


# Test
# Batch size, sequence length, model dim, attention heads (small smoke-test sizes).
B, S, D, H = 2, 16, 64, 4
block_ckpt = CheckpointedTransformerBlock(D, H, use_checkpoint=True)
# Built for side-by-side comparison; not exercised further below.
block_no_ckpt = CheckpointedTransformerBlock(D, H, use_checkpoint=False)

# Forward + backward through the checkpointed block to exercise recomputation.
x = torch.randn(B, S, D, requires_grad=True)
out = block_ckpt(x)
out.sum().backward()
print(f"With checkpointing - output: {out.shape}")

# On CPU-only machines these stats report zero usage.
stats = get_gpu_memory_stats()
print(f"GPU memory: {stats}")

Challenge 3: Profiling

Profile a model to find performance bottlenecks — the first step in any optimization effort.

import torch
import torch.nn as nn
import time
from contextlib import contextmanager

@contextmanager
def timer(label=""):
    """Context manager that prints the elapsed wall-clock time in ms."""
    t0 = time.perf_counter()
    yield
    dt_ms = (time.perf_counter() - t0) * 1000
    print(f"{label}: {dt_ms:.2f} ms")

class ModelProfiler:
    """
    Challenge: Build a profiler that measures:
    1. Forward pass time per layer
    2. Backward pass time
    3. Parameter count per layer
    4. Activation memory per layer
    """
    # YOUR SOLUTION HERE
    # Placeholder -- this class is redefined by the reference solution below.
    pass

# ---- SOLUTION ----
class ModelProfiler:
    """Lightweight per-layer profiler based on forward hooks.

    Measures average forward/backward time of the whole model, the
    (last-run) forward time and output shape of each named submodule, and
    per-parameter element counts. Timing uses host-side time.perf_counter,
    so per-layer numbers are approximate for asynchronous (CUDA) execution.
    """

    def __init__(self, model):
        self.model = model
        self.layer_times = {}          # layer name -> last forward time (ms)
        self.layer_output_shapes = {}  # layer name -> output tensor shape
        self.hooks = []                # live hook handles

    def _make_forward_hook(self, name):
        # Post-forward hook: records elapsed time since the matching
        # pre-hook fired, plus the output shape (tensor outputs only).
        def hook(module, input, output):
            if hasattr(module, '_start_time'):
                elapsed = time.perf_counter() - module._start_time
                self.layer_times[name] = elapsed * 1000  # ms
            if isinstance(output, torch.Tensor):
                self.layer_output_shapes[name] = tuple(output.shape)
        return hook

    def _make_forward_pre_hook(self, name):
        # Pre-forward hook: stamps the start time on the module itself.
        def hook(module, input):
            module._start_time = time.perf_counter()
        return hook

    def register_hooks(self):
        """Register timing hooks on all layers (skips the root module)."""
        for name, module in self.model.named_modules():
            if name == '':
                continue
            self.hooks.append(module.register_forward_pre_hook(
                self._make_forward_pre_hook(name)
            ))
            self.hooks.append(module.register_forward_hook(
                self._make_forward_hook(name)
            ))

    def remove_hooks(self):
        """Remove all hooks and clean up attributes the pre-hooks added."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        # FIX: the _start_time attributes stamped by the pre-hooks were
        # previously left behind on the profiled modules.
        for _, module in self.model.named_modules():
            if hasattr(module, '_start_time'):
                del module._start_time

    def profile(self, input_tensor, num_runs=10):
        """Run warmup plus timed forward/backward passes.

        Returns a dict with keys: forward_ms and backward_ms (averages over
        num_runs), layer_times and layer_shapes (from the most recent
        forward), and param_count.
        """
        self.register_hooks()
        try:
            # Warmup so lazy initialization / caching doesn't skew timings.
            with torch.no_grad():
                for _ in range(3):
                    self.model(input_tensor)

            # Profile forward pass
            forward_times = []
            for _ in range(num_runs):
                start = time.perf_counter()
                with torch.no_grad():
                    self.model(input_tensor)
                forward_times.append((time.perf_counter() - start) * 1000)

            # Profile backward pass (forward re-run outside the timed region)
            backward_times = []
            for _ in range(num_runs):
                out = self.model(input_tensor)
                start = time.perf_counter()
                out.sum().backward()
                backward_times.append((time.perf_counter() - start) * 1000)
                self.model.zero_grad()
        finally:
            # FIX: hooks are now removed even if a profiled pass raises.
            self.remove_hooks()

        return {
            'forward_ms': sum(forward_times) / num_runs,
            'backward_ms': sum(backward_times) / num_runs,
            'layer_times': dict(self.layer_times),
            'layer_shapes': dict(self.layer_output_shapes),
            'param_count': self.count_parameters(),
        }

    def count_parameters(self):
        """Return {parameter name: element count} plus a '_total' entry."""
        counts = {}
        total = 0
        for name, param in self.model.named_parameters():
            counts[name] = param.numel()
            total += param.numel()
        counts['_total'] = total
        return counts

# Test
# Small 3-layer MLP to profile.
model = nn.Sequential(
    nn.Linear(256, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
)
x = torch.randn(32, 256)

profiler = ModelProfiler(model)
results = profiler.profile(x, num_runs=5)
print(f"Forward:  {results['forward_ms']:.2f} ms")
print(f"Backward: {results['backward_ms']:.2f} ms")
print(f"Total params: {results['param_count']['_total']:,}")
# Per-layer breakdown: output shape and last recorded forward time.
for name, shape in results['layer_shapes'].items():
    t = results['layer_times'].get(name, 0)
    print(f"  {name}: {shape} ({t:.3f} ms)")

Challenge 4: Gradient Checking

Implement numerical gradient checking to verify your custom autograd functions and layers.

import torch
import torch.nn as nn

def numerical_gradient_check(func, inputs, eps=1e-5, atol=1e-4, rtol=1e-3):
    """
    Challenge: Implement numerical gradient checking.
    - Compute analytical gradients via autograd
    - Compute numerical gradients via finite differences
    - Compare and report mismatches
    - Handle multi-input, multi-output functions
    """
    # YOUR SOLUTION HERE
    # Placeholder -- redefined by the reference solution below.
    pass

# ---- SOLUTION ----
def numerical_gradient_check(func, inputs, eps=1e-5, atol=1e-4, rtol=1e-3):
    """
    Verify analytical gradients match numerical gradients.
    func: callable that takes *inputs and returns a scalar
    inputs: list of tensors (will compute gradients for those with requires_grad=True)
    """
    # Analytical gradients via autograd on detached clones.
    clones = [t.clone().detach().requires_grad_(t.requires_grad) for t in inputs]
    func(*clones).backward()

    results = []

    for idx, (original, clone) in enumerate(zip(inputs, clones)):
        if not original.requires_grad:
            continue

        grad_autograd = clone.grad.clone()
        grad_numeric = torch.zeros_like(original)

        base = original.detach().clone().view(-1)
        flat_numeric = grad_numeric.view(-1)
        for elem in range(base.numel()):
            def perturbed_eval(delta):
                # Evaluate func with element `elem` of input `idx` shifted
                # by `delta`, all other inputs detached and unchanged.
                shifted = base.clone()
                shifted[elem] += delta
                args = [shifted.view_as(original) if k == idx else t.detach()
                        for k, t in enumerate(inputs)]
                return func(*args).item()

            # Central difference: (f(x+eps) - f(x-eps)) / (2*eps)
            flat_numeric[elem] = (perturbed_eval(eps) - perturbed_eval(-eps)) / (2 * eps)

        # Compare analytical vs numerical gradients.
        abs_err = (grad_autograd - grad_numeric).abs().max().item()
        rel_err = abs_err / (grad_autograd.abs().max().item() + 1e-8)
        ok = abs_err < atol or rel_err < rtol

        results.append({
            'input_idx': idx,
            'max_abs_diff': abs_err,
            'max_rel_diff': rel_err,
            'passed': ok,
        })

        status = "PASS" if ok else "FAIL"
        print(f"Input {idx}: {status} (max abs diff: {abs_err:.2e}, max rel diff: {rel_err:.2e})")

    return all(r['passed'] for r in results)


# Test with a custom function
def my_function(x, w):
    # Scalar-valued function of two tensors: sum of squared entries of x @ w.
    return (x @ w).pow(2).sum()

# Random inputs; gradients are checked for both since requires_grad=True.
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 5, requires_grad=True)
passed = numerical_gradient_check(my_function, [x, w])
print(f"All gradients correct: {passed}")

Challenge 5: NaN Detection & Recovery

Build a system that detects NaN values during training and helps diagnose the root cause.

import torch
import torch.nn as nn

class NaNDetector:
    """
    Challenge: Build a NaN detection system that:
    1. Hooks into all layers to detect NaN activations
    2. Checks gradients after backward pass
    3. Identifies which layer first produces NaN
    4. Provides diagnostic information
    """
    # YOUR SOLUTION HERE
    # Placeholder -- this class is redefined by the reference solution below.
    pass

# ---- SOLUTION ----
class NaNDetector:
    """Hooks into a model to find the first layer producing NaN/Inf values.

    Registers forward hooks (activations) and full backward hooks
    (gradients) on every named submodule. The first offending tensor is
    reported to stdout and remembered in nan_layer / nan_type.
    """

    def __init__(self, model):
        self.model = model
        self.hooks = []            # live hook handles
        self.nan_detected = False  # set once, on the first offending tensor
        self.nan_layer = None      # name of the first offending layer
        self.nan_type = None       # 'activation' or 'gradient (...)'

    def _check_tensor(self, tensor, name, tensor_type):
        """Return True if `tensor` (or any tensor in a tuple) has NaN/Inf."""
        if tensor is None:
            return False
        if isinstance(tensor, tuple):
            # Backward hooks receive tuples of gradients; check each entry.
            return any(self._check_tensor(t, f"{name}[{i}]", tensor_type)
                      for i, t in enumerate(tensor))
        if isinstance(tensor, torch.Tensor):
            has_nan = torch.isnan(tensor).any().item()
            has_inf = torch.isinf(tensor).any().item()
            if has_nan or has_inf:
                # Report only the FIRST detection -- that is the layer where
                # the numerical problem originates.
                if not self.nan_detected:
                    self.nan_detected = True
                    self.nan_layer = name
                    self.nan_type = tensor_type
                    issue = "NaN" if has_nan else "Inf"
                    print(f"[NaN DETECTOR] {issue} detected in {tensor_type} "
                          f"at layer '{name}'")
                    print(f"  Shape: {tensor.shape}")
                    print(f"  Min: {tensor.min().item():.6f}")
                    print(f"  Max: {tensor.max().item():.6f}")
                    print(f"  NaN count: {torch.isnan(tensor).sum().item()}")
                    print(f"  Inf count: {torch.isinf(tensor).sum().item()}")
                return True
        return False

    def register_hooks(self):
        """Attach activation and gradient checks to every named submodule."""
        for name, module in self.model.named_modules():
            if name == '':
                continue

            # Forward hook: check activations
            # (name=name binds the loop variable at definition time).
            def forward_hook(mod, inp, out, name=name):
                self._check_tensor(out, name, 'activation')
            self.hooks.append(module.register_forward_hook(forward_hook))

            # Backward hook: check gradients
            def backward_hook(mod, grad_input, grad_output, name=name):
                self._check_tensor(grad_output, name, 'gradient (output)')
                self._check_tensor(grad_input, name, 'gradient (input)')
            self.hooks.append(module.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        """Detach all registered hooks."""
        for h in self.hooks:
            h.remove()
        self.hooks = []

    def check_parameters(self):
        """Check all parameters and their gradients for NaN or Inf.

        Returns a list of human-readable issue strings (empty when clean).
        """
        issues = []
        for name, param in self.model.named_parameters():
            if torch.isnan(param.data).any():
                issues.append(f"NaN in parameter: {name}")
            # FIX: also flag Inf, consistent with _check_tensor's checks.
            if torch.isinf(param.data).any():
                issues.append(f"Inf in parameter: {name}")
            if param.grad is not None and torch.isnan(param.grad).any():
                issues.append(f"NaN in gradient: {name}")
            if param.grad is not None and torch.isinf(param.grad).any():
                issues.append(f"Inf in gradient: {name}")
        return issues

    def diagnose(self):
        """Print common causes and fixes for NaN during training."""
        print("\n--- NaN Diagnosis Checklist ---")
        checks = [
            ("Learning rate too high", "Try reducing LR by 10x"),
            ("Missing gradient clipping", "Add torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)"),
            ("Log of zero or negative", "Add eps: torch.log(x + 1e-8)"),
            ("Division by zero", "Add eps to denominator: x / (y + 1e-8)"),
            ("Exploding activations", "Check for missing normalization layers"),
            ("Unstable loss function", "Use log_softmax instead of log(softmax)"),
            ("Bad weight initialization", "Use proper init: kaiming, xavier"),
        ]
        for cause, fix in checks:
            print(f"  - {cause}: {fix}")

# Test
# Small clean MLP -- no NaN is expected anywhere in this run.
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
)

detector = NaNDetector(model)
detector.register_hooks()

# Normal forward pass (no NaN expected)
x = torch.randn(4, 10)
out = model(x)
out.sum().backward()
print(f"NaN detected: {detector.nan_detected}")

# Check parameters
issues = detector.check_parameters()
print(f"Parameter issues: {issues}")

# Always remove hooks when done so later runs aren't instrumented.
detector.remove_hooks()

# Show diagnosis
detector.diagnose()
💡
Interview insight: When given buggy training code, check these five things in order: (1) Is the model in the right mode (train vs eval)? (2) Are gradients being zeroed? (3) Is data on the correct device? (4) Are intermediate results being detached where needed (.item() for logging)? (5) Is the loss function getting valid inputs (no log of zero, no division by zero)? Verbalizing this systematic approach scores points even if the bug is simple.