Advanced
Debugging & Optimization
Five challenges that test the debugging and optimization skills that separate senior DL engineers from junior ones. In interviews, you may be given broken code and asked to find the bug — these challenges prepare you for exactly that.
Challenge 1: Find the Bugs
Each code snippet below has one or more subtle bugs. Find them all.
import torch
import torch.nn as nn
# ---- BUG 1: The model that doesn't learn ----
class BuggyModel(nn.Module):
    """Two-layer MLP with an intentional bug left in for the exercise."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        # BUG (intentional): applying ReLU to the output layer clamps every
        # negative logit to zero and kills gradients for classification.
        return self.relu(self.fc2(hidden))
# FIX: Remove the final ReLU -- logits should be raw for cross-entropy
# ---- BUG 2: Training loop memory leak ----
def buggy_train(model, loader, optimizer, criterion):
    """One training epoch with an intentional memory leak.

    The running loss is accumulated as a *tensor*, so every batch's
    autograd graph is kept alive instead of being freed after backward.
    """
    running = 0.0
    for inputs, targets in loader:
        optimizer.zero_grad()
        predictions = model(inputs)
        batch_loss = criterion(predictions, targets)
        batch_loss.backward()
        optimizer.step()
        # BUG (intentional): adding the tensor itself keeps its graph alive
        running = running + batch_loss
    return running / len(loader)
# FIX: Use total_loss += loss.item() -- .item() extracts scalar, frees graph
# ---- BUG 3: Validation with wrong mode ----
def buggy_evaluate(model, loader, criterion):
    """Evaluation loop with two intentional bugs:

    1. model.eval() is never called, so dropout/batchnorm still behave as
       if training.
    2. torch.no_grad() is missing, so autograd wastes memory building a
       computation graph that is never used.
    """
    loss_sum = 0.0
    for inputs, targets in loader:
        loss_sum += criterion(model(inputs), targets).item()
    return loss_sum / len(loader)
# FIX:
# model.eval()
# with torch.no_grad():
#     ... (evaluation code)
# model.train()  # restore training mode after
# ---- BUG 4: Gradient accumulation gone wrong ----
def buggy_gradient_accumulation(model, loader, optimizer, criterion, accum_steps=4):
    """Gradient accumulation with an intentional bug: gradients are never
    zeroed, so they keep piling up across successive mega-batches."""
    for step, (inputs, targets) in enumerate(loader, start=1):
        scaled_loss = criterion(model(inputs), targets) / accum_steps
        scaled_loss.backward()
        if step % accum_steps == 0:
            optimizer.step()
            # BUG (intentional): optimizer.zero_grad() is missing here!
            # Gradients keep accumulating across mega-batches
# FIX: Add optimizer.zero_grad() after optimizer.step()
# ---- BUG 5: The sneaky device mismatch ----
def buggy_inference(model, x):
    """Run a forward pass with an intentional device-mismatch bug.

    The model is moved to the GPU but the input tensor `x` is not, so on
    any CUDA machine the forward call raises a RuntimeError.
    """
    model = model.cuda()
    # BUG: x is still on CPU!
    out = model(x)  # RuntimeError: expected all tensors on same device
    return out
# FIX: x = x.cuda() OR x = x.to(next(model.parameters()).device)
print("All 5 bugs identified and fixed!")
Challenge 2: Memory Optimization
Reduce GPU memory usage for training large models — the difference between fitting on one GPU and needing four.
import torch
import torch.nn as nn
class MemoryEfficientTrainer:
    """
    Challenge: Implement memory optimization techniques:
    1. Gradient checkpointing (trade compute for memory)
    2. In-place operations where safe
    3. Clearing intermediate activations
    4. Monitoring GPU memory usage
    """
    # YOUR SOLUTION HERE
    # Placeholder stub for the exercise -- the reference solution follows
    # below in this file.
    pass
# ---- SOLUTION ----
from torch.utils.checkpoint import checkpoint
class CheckpointedTransformerBlock(nn.Module):
    """Pre-norm transformer block with optional gradient checkpointing.

    When ``use_checkpoint`` is True and the module is in training mode,
    the attention and feed-forward sub-blocks drop their activations on
    the forward pass and recompute them during backward, trading extra
    compute for a substantial reduction in stored-activation memory.
    """

    def __init__(self, d_model, num_heads, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def _attn_block(self, x):
        # Pre-norm self-attention with a residual connection.
        normalized = self.norm1(x)
        attn_out, _ = self.attn(normalized, normalized, normalized)
        return x + attn_out

    def _ffn_block(self, x):
        # Pre-norm feed-forward with a residual connection.
        return x + self.ffn(self.norm2(x))

    def forward(self, x):
        checkpointing = self.use_checkpoint and self.training
        for sub_block in (self._attn_block, self._ffn_block):
            if checkpointing:
                # Don't store activations now; recompute them in backward.
                x = checkpoint(sub_block, x, use_reentrant=False)
            else:
                x = sub_block(x)
        return x
def get_gpu_memory_stats():
    """Return current GPU memory usage in megabytes.

    Returns a dict with keys ``allocated_mb``, ``reserved_mb`` and
    ``max_allocated_mb``. On a machine without CUDA every value is 0.0,
    so callers can index the same keys unconditionally.
    """
    if not torch.cuda.is_available():
        # FIX: the CPU fallback previously returned different keys
        # ("allocated", "reserved", "max_allocated") than the CUDA path,
        # so any caller indexing the "*_mb" keys raised KeyError on
        # CPU-only machines. The keys now match in both branches.
        return {"allocated_mb": 0.0, "reserved_mb": 0.0, "max_allocated_mb": 0.0}
    return {
        "allocated_mb": torch.cuda.memory_allocated() / 1024**2,
        "reserved_mb": torch.cuda.memory_reserved() / 1024**2,
        "max_allocated_mb": torch.cuda.max_memory_allocated() / 1024**2,
    }
def memory_efficient_inference(model, data_loader, device):
    """Memory-efficient inference for large datasets.

    Each batch is moved to ``device``, its predictions are shipped back to
    the CPU right away, the GPU-side references are dropped, and the CUDA
    cache is emptied -- keeping peak GPU memory near one batch's worth.
    """
    model.eval()
    cpu_chunks = []
    with torch.no_grad():
        for batch in data_loader:
            inputs = batch.to(device)
            outputs = model(inputs)
            # Keep results on the CPU so the GPU copies can be reclaimed.
            cpu_chunks.append(outputs.cpu())
            del inputs, outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    return torch.cat(cpu_chunks, dim=0)
# Test
B, S, D, H = 2, 16, 64, 4  # batch, sequence length, model dim, attention heads
block_ckpt = CheckpointedTransformerBlock(D, H, use_checkpoint=True)
block_no_ckpt = CheckpointedTransformerBlock(D, H, use_checkpoint=False)
x = torch.randn(B, S, D, requires_grad=True)
# Modules default to training mode, so the checkpointed path is exercised.
out = block_ckpt(x)
out.sum().backward()
print(f"With checkpointing - output: {out.shape}")
stats = get_gpu_memory_stats()
print(f"GPU memory: {stats}")
Challenge 3: Profiling
Profile a model to find performance bottlenecks — the first step in any optimization effort.
import torch
import torch.nn as nn
import time
from contextlib import contextmanager
@contextmanager
def timer(label=""):
    """Context manager that prints the wall-clock time of its body in ms."""
    begin = time.perf_counter()
    yield
    duration_ms = (time.perf_counter() - begin) * 1000
    print(f"{label}: {duration_ms:.2f} ms")
class ModelProfiler:
    """
    Challenge: Build a profiler that measures:
    1. Forward pass time per layer
    2. Backward pass time
    3. Parameter count per layer
    4. Activation memory per layer
    """
    # YOUR SOLUTION HERE
    # Placeholder stub for the exercise -- the reference implementation
    # below redefines this class.
    pass
# ---- SOLUTION ----
class ModelProfiler:
    """Hook-based profiler: per-layer forward times, whole-model forward
    and backward times, output shapes, and parameter counts."""

    def __init__(self, model):
        self.model = model
        self.layer_times = {}          # layer name -> forward time in ms
        self.layer_output_shapes = {}  # layer name -> output tensor shape
        self.hooks = []

    def _make_forward_pre_hook(self, name):
        # Stamp a start time on the module just before it runs.
        def hook(module, input):
            module._start_time = time.perf_counter()
        return hook

    def _make_forward_hook(self, name):
        # Read the stamp after the module runs and record elapsed ms.
        def hook(module, input, output):
            if hasattr(module, '_start_time'):
                self.layer_times[name] = (time.perf_counter() - module._start_time) * 1000  # ms
            if isinstance(output, torch.Tensor):
                self.layer_output_shapes[name] = tuple(output.shape)
        return hook

    def register_hooks(self):
        """Register timing hooks on all layers."""
        for name, module in self.model.named_modules():
            if not name:  # skip the root module itself
                continue
            self.hooks.append(module.register_forward_pre_hook(
                self._make_forward_pre_hook(name)
            ))
            self.hooks.append(module.register_forward_hook(
                self._make_forward_hook(name)
            ))

    def remove_hooks(self):
        """Detach every registered hook."""
        while self.hooks:
            self.hooks.pop().remove()

    def profile(self, input_tensor, num_runs=10):
        """Run profiling and return aggregate results as a dict."""
        self.register_hooks()
        # Warmup so one-time lazy initialization doesn't pollute timings.
        with torch.no_grad():
            for _ in range(3):
                self.model(input_tensor)
        # Time the forward pass.
        forward_times = []
        for _ in range(num_runs):
            t0 = time.perf_counter()
            with torch.no_grad():
                self.model(input_tensor)
            forward_times.append((time.perf_counter() - t0) * 1000)
        # Time the backward pass (needs a fresh graph per run).
        backward_times = []
        for _ in range(num_runs):
            out = self.model(input_tensor)
            t0 = time.perf_counter()
            out.sum().backward()
            backward_times.append((time.perf_counter() - t0) * 1000)
            self.model.zero_grad()
        self.remove_hooks()
        return {
            'forward_ms': sum(forward_times) / num_runs,
            'backward_ms': sum(backward_times) / num_runs,
            'layer_times': dict(self.layer_times),
            'layer_shapes': dict(self.layer_output_shapes),
            'param_count': self.count_parameters(),
        }

    def count_parameters(self):
        """Count parameters per layer, plus a '_total' entry."""
        counts = {name: p.numel() for name, p in self.model.named_parameters()}
        counts['_total'] = sum(counts.values())
        return counts
# Test
model = nn.Sequential(
    nn.Linear(256, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
)
x = torch.randn(32, 256)
profiler = ModelProfiler(model)
results = profiler.profile(x, num_runs=5)
print(f"Forward: {results['forward_ms']:.2f} ms")
print(f"Backward: {results['backward_ms']:.2f} ms")
print(f"Total params: {results['param_count']['_total']:,}")
# Per-layer breakdown: output shape plus recorded forward time.
for name, shape in results['layer_shapes'].items():
    t = results['layer_times'].get(name, 0)
    print(f" {name}: {shape} ({t:.3f} ms)")
Challenge 4: Gradient Checking
Implement numerical gradient checking to verify your custom autograd functions and layers.
import torch
import torch.nn as nn
def numerical_gradient_check(func, inputs, eps=1e-5, atol=1e-4, rtol=1e-3):
    """
    Challenge: Implement numerical gradient checking.
    - Compute analytical gradients via autograd
    - Compute numerical gradients via finite differences
    - Compare and report mismatches
    - Handle multi-input, multi-output functions
    """
    # YOUR SOLUTION HERE
    # Placeholder stub -- the reference implementation below redefines
    # this function.
    pass
# ---- SOLUTION ----
def numerical_gradient_check(func, inputs, eps=1e-5, atol=1e-4, rtol=1e-3):
"""
Verify analytical gradients match numerical gradients.
func: callable that takes *inputs and returns a scalar
inputs: list of tensors (will compute gradients for those with requires_grad=True)
"""
# Compute analytical gradients
inputs_clone = [x.clone().detach().requires_grad_(x.requires_grad) for x in inputs]
output = func(*inputs_clone)
output.backward()
results = []
for i, (inp, inp_clone) in enumerate(zip(inputs, inputs_clone)):
if not inp.requires_grad:
continue
analytical_grad = inp_clone.grad.clone()
numerical_grad = torch.zeros_like(inp)
# Compute numerical gradient via central differences
flat_inp = inp.detach().clone().view(-1)
for j in range(flat_inp.numel()):
# f(x + eps)
flat_inp_plus = flat_inp.clone()
flat_inp_plus[j] += eps
inp_plus = flat_inp_plus.view_as(inp)
inputs_plus = [inp_plus if k == i else x.detach() for k, x in enumerate(inputs)]
out_plus = func(*inputs_plus)
# f(x - eps)
flat_inp_minus = flat_inp.clone()
flat_inp_minus[j] -= eps
inp_minus = flat_inp_minus.view_as(inp)
inputs_minus = [inp_minus if k == i else x.detach() for k, x in enumerate(inputs)]
out_minus = func(*inputs_minus)
# Central difference: (f(x+eps) - f(x-eps)) / (2*eps)
numerical_grad.view(-1)[j] = (out_plus.item() - out_minus.item()) / (2 * eps)
# Compare
max_diff = (analytical_grad - numerical_grad).abs().max().item()
rel_diff = max_diff / (analytical_grad.abs().max().item() + 1e-8)
passed = max_diff < atol or rel_diff < rtol
results.append({
'input_idx': i,
'max_abs_diff': max_diff,
'max_rel_diff': rel_diff,
'passed': passed,
})
status = "PASS" if passed else "FAIL"
print(f"Input {i}: {status} (max abs diff: {max_diff:.2e}, max rel diff: {rel_diff:.2e})")
return all(r['passed'] for r in results)
# Test with a custom function
def my_function(x, w):
    """Scalar objective for gradient checking: sum of squared entries of x @ w."""
    return ((x @ w) ** 2).sum()
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 5, requires_grad=True)
# Both x and w require grad, so both are verified against finite differences.
passed = numerical_gradient_check(my_function, [x, w])
print(f"All gradients correct: {passed}")
Challenge 5: NaN Detection & Recovery
Build a system that detects NaN values during training and helps diagnose the root cause.
import torch
import torch.nn as nn
class NaNDetector:
    """
    Challenge: Build a NaN detection system that:
    1. Hooks into all layers to detect NaN activations
    2. Checks gradients after backward pass
    3. Identifies which layer first produces NaN
    4. Provides diagnostic information
    """
    # YOUR SOLUTION HERE
    # Placeholder stub -- the reference implementation below redefines
    # this class.
    pass
# ---- SOLUTION ----
class NaNDetector:
    """Hook-based NaN/Inf watchdog for activations and gradients."""

    def __init__(self, model):
        self.model = model
        self.hooks = []
        self.nan_detected = False
        self.nan_layer = None  # name of the first offending layer
        self.nan_type = None   # 'activation' or 'gradient'

    def _check_tensor(self, tensor, name, tensor_type):
        """Return True (and report) if `tensor` contains NaN or Inf."""
        if tensor is None:
            return False
        if isinstance(tensor, tuple):
            # Backward hooks hand us tuples (grad_input/grad_output);
            # check every slot.
            return any(self._check_tensor(t, f"{name}[{i}]", tensor_type)
                       for i, t in enumerate(tensor))
        if not isinstance(tensor, torch.Tensor):
            return False
        has_nan = torch.isnan(tensor).any().item()
        has_inf = torch.isinf(tensor).any().item()
        if not (has_nan or has_inf):
            return False
        if not self.nan_detected:
            # Remember only the first offender -- that's the root cause.
            self.nan_detected = True
            self.nan_layer = name
            self.nan_type = tensor_type
        issue = "NaN" if has_nan else "Inf"
        print(f"[NaN DETECTOR] {issue} detected in {tensor_type} "
              f"at layer '{name}'")
        print(f" Shape: {tensor.shape}")
        print(f" Min: {tensor.min().item():.6f}")
        print(f" Max: {tensor.max().item():.6f}")
        print(f" NaN count: {torch.isnan(tensor).sum().item()}")
        print(f" Inf count: {torch.isinf(tensor).sum().item()}")
        return True

    def register_hooks(self):
        """Attach forward and full-backward hooks to every named submodule."""
        for layer_name, layer in self.model.named_modules():
            if not layer_name:  # skip the root module itself
                continue

            # Forward hook: check activations
            def on_forward(mod, inp, out, layer_name=layer_name):
                self._check_tensor(out, layer_name, 'activation')

            # Backward hook: check gradients
            def on_backward(mod, grad_input, grad_output, layer_name=layer_name):
                self._check_tensor(grad_output, layer_name, 'gradient (output)')
                self._check_tensor(grad_input, layer_name, 'gradient (input)')

            self.hooks.append(layer.register_forward_hook(on_forward))
            self.hooks.append(layer.register_full_backward_hook(on_backward))

    def remove_hooks(self):
        """Detach every registered hook."""
        while self.hooks:
            self.hooks.pop().remove()

    def check_parameters(self):
        """Check all parameters and their gradients for NaN."""
        issues = []
        for name, param in self.model.named_parameters():
            if torch.isnan(param.data).any():
                issues.append(f"NaN in parameter: {name}")
            if param.grad is not None and torch.isnan(param.grad).any():
                issues.append(f"NaN in gradient: {name}")
        return issues

    def diagnose(self):
        """Common causes and fixes for NaN."""
        print("\n--- NaN Diagnosis Checklist ---")
        checks = [
            ("Learning rate too high", "Try reducing LR by 10x"),
            ("Missing gradient clipping", "Add torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)"),
            ("Log of zero or negative", "Add eps: torch.log(x + 1e-8)"),
            ("Division by zero", "Add eps to denominator: x / (y + 1e-8)"),
            ("Exploding activations", "Check for missing normalization layers"),
            ("Unstable loss function", "Use log_softmax instead of log(softmax)"),
            ("Bad weight initialization", "Use proper init: kaiming, xavier"),
        ]
        for cause, fix in checks:
            print(f" - {cause}: {fix}")
# Test
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
)
detector = NaNDetector(model)
detector.register_hooks()
# Normal forward pass (no NaN expected)
x = torch.randn(4, 10)
out = model(x)
out.sum().backward()
print(f"NaN detected: {detector.nan_detected}")
# Check parameters
issues = detector.check_parameters()
print(f"Parameter issues: {issues}")
detector.remove_hooks()
# Show diagnosis
detector.diagnose()
Interview insight: When given buggy training code, check these five things in order: (1) Is the model in the right mode (train vs eval)? (2) Are gradients being zeroed? (3) Is data on the correct device? (4) Are intermediate results being detached where needed (.item() for logging)? (5) Is the loss function getting valid inputs (no log of zero, no division by zero)? Verbalizing this systematic approach scores points even if the bug is simple.
Lilly Tech Systems