Advanced

Transformers & Attention

15 interview questions on the architecture that powers modern AI. Transformers are the single most important topic in DL interviews today — expect at least 2-3 questions on this topic in any ML interview loop.

Q1: Explain scaled dot-product attention. Why do we scale by sqrt(d_k)?

A

Scaled dot-product attention: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V

Q (queries), K (keys), and V (values) are linear projections of the input. QK^T computes the similarity between every pair of positions (an n*n attention matrix for a sequence of length n). The softmax, applied to each row, turns each query's scores into a probability distribution over the keys, and the weighted sum of V with those probabilities produces the output.
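A minimal sketch of the formula as a plain function, assuming inputs shaped (batch, seq_len, d_k). The helper name `attention` is our own; PyTorch 2.x also ships a fused `torch.nn.functional.scaled_dot_product_attention` that computes the same quantity.

```python
import math
import torch

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for tensors shaped (..., seq, d_k)."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (..., seq_q, seq_k) similarities
    weights = torch.softmax(scores, dim=-1)            # each row: distribution over keys
    return weights @ V, weights                        # weighted sum of value vectors

torch.manual_seed(0)
Q = torch.randn(1, 4, 8)   # 4 query positions, d_k = 8
K = torch.randn(1, 6, 8)   # 6 key positions
V = torch.randn(1, 6, 8)
out, w = attention(Q, K, V)
print(out.shape, w.shape)  # torch.Size([1, 4, 8]) torch.Size([1, 4, 6])
print(w.sum(dim=-1))       # every row sums to 1
```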

Why scale by sqrt(d_k)? If the components of a query q and a key k are independent with mean 0 and variance 1, their dot product has mean 0 and variance d_k, so its typical magnitude grows like sqrt(d_k). For large d_k (e.g., 64), the scores become large enough to push the softmax into saturated regions where its outputs are near 0 or 1 and its gradients are extremely small. Dividing by sqrt(d_k) brings the variance back to ~1, keeping gradients through the softmax healthy.

Concrete example: With d_k=64 and query/key components drawn independently with unit variance, the dot product has std_dev = sqrt(64) = 8. Softmax of values like [-8, 2, 8, -5] produces [0.0000, 0.0025, 0.9975, 0.0000], nearly one-hot with vanishing gradients. After dividing by 8: [-1, 0.25, 1, -0.625], whose softmax is about [0.075, 0.262, 0.554, 0.109], a much smoother distribution.
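A quick numerical check of both claims, assuming query and key components drawn i.i.d. from N(0, 1); the sample size and seed are arbitrary choices.

```python
import math
import torch

torch.manual_seed(0)
d_k = 64
# Components drawn i.i.d. from N(0, 1): the assumption behind the variance argument
q = torch.randn(50_000, d_k)
k = torch.randn(50_000, d_k)
dots = (q * k).sum(dim=-1)                    # 50,000 sample dot products
print(dots.var().item())                      # close to d_k = 64
print((dots / math.sqrt(d_k)).var().item())   # close to 1

# Logits from the example above, before and after dividing by sqrt(64) = 8
raw = torch.tensor([-8.0, 2.0, 8.0, -5.0])
print(torch.softmax(raw, dim=-1))                   # ~[0.0000, 0.0025, 0.9975, 0.0000]
print(torch.softmax(raw / math.sqrt(d_k), dim=-1))  # ~[0.075, 0.262, 0.554, 0.109]
```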

Q2: Implement multi-head attention from scratch in PyTorch.

A

This is perhaps the single most common DL coding question. Practice until you can write it without looking at references.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Linear projections: (batch, seq_len, d_model) -> (batch, seq_len, d_model)
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # Split into heads: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        # (batch, heads, seq_q, d_k) @ (batch, heads, d_k, seq_k) = (batch, heads, seq_q, seq_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # mask must broadcast to (batch, heads, seq_q, seq_k); 0/False marks blocked positions
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (batch, heads, seq_q, seq_k) @ (batch, heads, seq_k, d_k) = (batch, heads, seq_q, d_k)
        attn_output = torch.matmul(attn_weights, V)

        # Concat heads: (batch, heads, seq_q, d_k) -> (batch, seq_q, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.W_o(attn_output)

# Test
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
out = mha(x, x, x)  # Self-attention
print(f"Output: {out.shape}")  # [2, 10, 512]

Q3: Why do we use multi-head attention instead of a single attention head with the same total dimensions?

A

Multiple heads allow the model to attend to information from different representation subspaces at different positions simultaneously. A single head with d_model=512 computes one attention pattern. Eight heads with d_k=64 each compute eight different attention patterns.

Empirical evidence: Different heads learn different types of attention patterns. In language models, some heads attend to syntactic dependencies (subject-verb), others to positional patterns (previous/next token), and others to semantic relationships. A single head produces only one softmax distribution per query position, so it would have to blend these competing patterns into one weighted average instead of tracking them separately.

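Computation cost: Multi-head attention has essentially the same parameter count and FLOPs as single-head attention with the same d_model. W_q, W_k, and W_v have the same d_model x d_model shapes either way, and h heads of width d_k = d_model/h cost h * n^2 * d_k = n^2 * d_model multiply-adds for the scores, the same as one head of width d_model. The main additions are the output projection W_o that mixes the concatenated heads, some reshaping overhead, and storing h attention matrices instead of one.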
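A small sketch checking the cost claims with PyTorch's `nn.MultiheadAttention`. Note that this module includes the output projection even when `num_heads=1`, so W_o is counted on both sides here; `num_params` is a hypothetical helper.

```python
import torch.nn as nn

def num_params(module):
    return sum(p.numel() for p in module.parameters())

d_model = 512
single = nn.MultiheadAttention(d_model, num_heads=1)
multi = nn.MultiheadAttention(d_model, num_heads=8)

# Q/K/V in-projection (3 * d_model^2 + 3 * d_model) plus out_proj (d_model^2 + d_model)
print(num_params(single), num_params(multi))  # 1050624 1050624

# Score matmul cost: h heads of width d_k = d_model / h vs. one head of width d_model
n, h = 1024, 8
print(h * n * n * (d_model // h) == n * n * d_model)  # True
```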

Q4: How does positional encoding work? Compare sinusoidal vs. learned positional embeddings.

A

Why needed: Self-attention is permutation-equivariant: permuting the input tokens simply permutes the outputs in the same way, so without extra information the model cannot tell different token orders apart. Positional encoding injects position information so the model can distinguish "The cat sat on the mat" from "The mat sat on the cat."
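A quick check of permutation equivariance using PyTorch's `nn.MultiheadAttention` with no positional encoding added; the sizes and the permutation are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True).eval()
x = torch.randn(1, 5, 16)           # (batch, seq_len, d_model), no positional encoding
perm = torch.tensor([3, 0, 4, 1, 2])

with torch.no_grad():
    out, _ = attn(x, x, x)
    out_perm, _ = attn(x[:, perm], x[:, perm], x[:, perm])

# Shuffling the tokens only shuffles the outputs the same way
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True
```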

Sinusoidal (original Transformer): PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). Advantages: no learnable parameters, and the encoding can be computed for any position, including positions beyond the longest sequence seen during training (though models often still extrapolate poorly in practice). Disadvantage: fixed pattern, not task-specific.

Learned positional embeddings: A learnable embedding table indexed by position (like word embeddings but for positions). Used by BERT, GPT-2. Advantages: can learn task-specific position patterns. Disadvantage: cannot generalize beyond the maximum training length.

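RoPE (Rotary Position Embedding): Used by LLaMA, GPT-NeoX. Encodes position by rotating each 2D pair of dimensions of Q and K by an angle proportional to the token's position. Key property: the dot product between a query at position m and a key at position n depends only on the offset m - n, so relative position is encoded directly in the attention scores. This has made RoPE a popular base for long-context models, although extrapolating well beyond the training length usually still requires extensions such as position interpolation.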
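A minimal RoPE sketch, assuming the interleaved-pair formulation (dimensions 2i and 2i+1 form a pair rotated by pos * 10000^(-2i/d)); some implementations pair the first and second halves of the vector instead, which is a fixed permutation of the same idea. The helpers `rope` and `score` are hypothetical names used only for illustration.

```python
import torch

def rope(x, positions, base=10000.0):
    """Rotate each pair of dims (2i, 2i+1) of x by angle pos * base^(-2i/d).
    x: (seq_len, d) with d even; positions: (seq_len,) integer positions."""
    d = x.size(-1)
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)  # (d/2,)
    angles = positions.float()[:, None] * inv_freq[None, :]               # (seq_len, d/2)
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin   # 2D rotation of each (even, odd) pair
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

torch.manual_seed(0)
q, k = torch.randn(1, 64), torch.randn(1, 64)

def score(m, n):
    # Dot product between the query placed at position m and the key at position n
    return (rope(q, torch.tensor([m])) * rope(k, torch.tensor([n]))).sum().item()

# Same offset (m - n = 3) at different absolute positions gives the same score
print(score(5, 2), score(45, 42))
```

Because each pair is rotated rather than shifted, vector norms are preserved and only the angle between q and k (and hence the offset) affects the score.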

import torch
import math

class SinusoidalPositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # div_term[i] = 1 / 10000^(2i/d_model), computed in log space
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)  # Even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # Odd dimensions

        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]

# Usage
pe = SinusoidalPositionalEncoding(d_model=512)
x = torch.randn(2, 100, 512)
x_with_pos = pe(x)  # Positions are added to embeddings

Q5: What is the causal attention mask and why is it needed in decoder models?

A

A causal (or autoregressive) mask prevents each position from attending to future positions. It sets every entry above the diagonal of the attention score matrix (a strictly upper-triangular pattern) to -infinity before the softmax, so those positions receive zero attention weight.

Why needed: During language generation, the model predicts one token at a time. When predicting token 5, it should only see tokens 1-4, not token 5 itself or anything after it. During training, all positions are predicted in parallel from the full sequence; without the mask, the model would "cheat" by reading the very tokens it is supposed to predict, and then fail at generation time when those future tokens do not exist.

Encoder vs. decoder: Encoders (BERT) use no causal mask — each token sees all other tokens (bidirectional). Decoders (GPT) use a causal mask — each token only sees previous tokens (unidirectional).

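import torch

# Causal attention mask
def create_causal_mask(seq_len):
    """Returns mask where True = attend, False = mask out.
    Note: nn.MultiheadAttention's boolean attn_mask uses the opposite convention
    (True = NOT allowed to attend), so pass ~mask there."""
    mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
    return mask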

mask = create_causal_mask(5)
# tensor([[ True, False, False, False, False],
#         [ True,  True, False, False, False],
#         [ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])
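To see the mask in action, here is a minimal single-head sketch (the helper name `causal_attention` is our own) that uses the same True = attend convention and checks that editing the last token leaves every earlier output unchanged:

```python
import math
import torch

def causal_attention(Q, K, V):
    """Single-head causal attention; Q, K, V: (seq_len, d_k)."""
    seq_len, d_k = Q.shape
    mask = torch.tril(torch.ones(seq_len, seq_len)).bool()   # True = attend
    scores = Q @ K.T / math.sqrt(d_k)
    scores = scores.masked_fill(~mask, float('-inf'))        # block future positions
    return torch.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
x = torch.randn(6, 8)
out = causal_attention(x, x, x)

# Changing the last token must not affect any earlier output
x2 = x.clone()
x2[-1] = torch.randn(8)
out2 = causal_attention(x2, x2, x2)
print(torch.allclose(out[:-1], out2[:-1]))  # True
```

The first position can only attend to itself, so its output is exactly its own value vector.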

Q6: Compare BERT and GPT architectures. When would you use each?

A
| Aspect | BERT | GPT |
| --- | --- | --- |
| Architecture | Encoder-only Transformer | Decoder-only Transformer |
| Attention | Bidirectional (sees all tokens) | Causal/unidirectional (sees only past tokens) |
| Pre-training | Masked Language Modeling (predict [MASK] tokens) + Next Sentence Prediction | Autoregressive LM (predict the next token) |
| Use cases | Classification, NER, QA, sentence similarity: tasks where the full input is available | Text generation, code completion, reasoning, chat: autoregressive tasks |
| Fine-tuning | Add a task-specific head, fine-tune all layers | In-context learning (few-shot prompting), or fine-tuning (full or with LoRA) |
| Scale | BERT-Large: 340M params | GPT-3: 175B params; GPT-4: size not officially disclosed (unofficial estimates ~1.8T, MoE) |

Key insight: BERT is better for understanding (classification, extraction); GPT is better for generation. Modern large language models are predominantly GPT-style because generation subsumes most understanding tasks via prompting.

Q7: What is the computational complexity of self-attention? Why is it a bottleneck?

A

Time complexity: O(n^2 * d) where n is sequence length and d is the model dimension. The QK^T matrix multiplication produces an n*n attention matrix.

Memory complexity: O(n^2) for storing the attention matrix (per head, per layer).

Why it is a bottleneck: For n=4096, the attention matrix is 4096*4096 = 16.7M entries per head. With 32 heads and 32 layers, that is 17B entries. Doubling the sequence length quadruples the computation and memory.
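The arithmetic above, plus the memory it would take if every attention matrix were kept at once in fp16 (as standard, non-fused attention does for the backward pass during training), for a single sequence:

```python
n, heads, layers = 4096, 32, 32
entries_per_head = n * n
total_entries = entries_per_head * heads * layers
print(f"{entries_per_head:,}")            # 16,777,216
print(f"{total_entries:,}")               # 17,179,869,184

# fp16 = 2 bytes per entry
print(total_entries * 2 / 2**30, "GiB")   # 32.0 GiB

# Doubling n quadruples the count
print((2 * n) ** 2 // n ** 2)             # 4
```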

Efficient attention variants:

  • Flash Attention: Does not reduce the O(n^2) complexity but dramatically reduces memory by computing attention block-by-block without materializing the full n*n matrix. 2-4x faster in practice.
  • Sparse attention (Longformer, BigBird): Each token attends to only a subset of positions (local window + global tokens). O(n * k) where k is the window size.
  • Linear attention (Performer): Approximates softmax attention with kernel feature maps (random features), so the cost is linear in n, about O(n * m * d) with m random features, but quality is often lower than exact attention.
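Performer approximates the softmax kernel with random features; the sketch below instead uses the simpler elu(x) + 1 feature map from Katharopoulos et al. (2020), which replaces softmax rather than approximating it, but relies on the same associativity trick that makes these methods linear in n. It is non-causal, and all function names are our own.

```python
import torch
import torch.nn.functional as F

def feature_map(x):
    # Positive features (elu + 1), a swapped-in choice, not Performer's random features
    return F.elu(x) + 1

def quadratic_form(Q, K, V):
    # Materializes the (n, n) similarity matrix: O(n^2 * d)
    A = feature_map(Q) @ feature_map(K).T
    return (A @ V) / A.sum(dim=-1, keepdim=True)

def linear_form(Q, K, V):
    # Same result by associativity: phi(Q) (phi(K)^T V) costs O(n * d^2), linear in n
    phiQ, phiK = feature_map(Q), feature_map(K)
    KV = phiK.T @ V                        # (d, d_v) summary of all keys and values
    normalizer = phiQ @ phiK.sum(dim=0)    # (n,) row sums of phi(Q) phi(K)^T
    return (phiQ @ KV) / normalizer[:, None]

torch.manual_seed(0)
n, d = 512, 32
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
print(torch.allclose(quadratic_form(Q, K, V), linear_form(Q, K, V), atol=1e-4))  # True
```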

Q8: What is the KV-cache and why is it essential for efficient autoregressive generation?

A

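During autoregressive generation, each new token must attend to all previous tokens. Without caching, every step reruns the model over the whole prefix: generating token n recomputes the K and V projections for all n-1 previous tokens (O(n) projection work per step, O(n^2) over the whole sequence) and rebuilds the full attention matrix for the prefix (O(n^2) per step, O(n^3) in total).

KV-cache: Store the K and V vectors for all previous tokens. When generating token n, compute Q, K, and V only for the new token, append its K and V to the cache, and attend over the cached keys and values. Each step then costs O(n) attention work instead of recomputing everything, for O(n^2) in total.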
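A minimal KV-cache sketch for a toy single-head layer with random weights (all names here are hypothetical). It feeds a fixed input sequence one token at a time rather than sampling tokens, and checks that cached decoding matches recomputing causal attention from scratch:

```python
import math
import torch

torch.manual_seed(0)
d = 16
W_q, W_k, W_v = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))  # toy projections

def full_causal(x):
    # No cache: recompute Q, K, V and the full causal attention for the whole prefix
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    n = x.size(0)
    scores = Q @ K.T / math.sqrt(d)
    mask = torch.ones(n, n).tril().bool()
    scores = scores.masked_fill(~mask, float('-inf'))
    return torch.softmax(scores, dim=-1) @ V

def decode_step(x_t, cache):
    # x_t: (1, d) new token. Project only this token and append its K, V to the cache
    q, k, v = x_t @ W_q, x_t @ W_k, x_t @ W_v
    cache["K"] = torch.cat([cache["K"], k])
    cache["V"] = torch.cat([cache["V"], v])
    scores = q @ cache["K"].T / math.sqrt(d)   # (1, t): O(t * d) work for this step
    return torch.softmax(scores, dim=-1) @ cache["V"]

x = torch.randn(10, d)
cache = {"K": torch.empty(0, d), "V": torch.empty(0, d)}
steps = torch.cat([decode_step(x[t:t + 1], cache) for t in range(10)])
print(torch.allclose(steps, full_causal(x), atol=1e-5))  # True
print(cache["K"].shape)                                  # torch.Size([10, 16])
```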

Memory cost: 2 * num_layers * num_heads * seq_len * d_k * bytes_per_element (the factor 2 covers K and V). For a 7B parameter model (32 layers, 32 heads, d_k=128) with fp16 at sequence length 4096: 2 * 32 * 32 * 4096 * 128 * 2 bytes = 2 GiB per sequence, and the total grows linearly with batch size. This is why context length is memory-limited.
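The memory formula as a small calculator. The helper name and the `batch_size` argument are our own additions, and `num_kv_heads` equals the number of attention heads for standard multi-head attention.

```python
def kv_cache_bytes(num_layers, num_kv_heads, seq_len, d_k, bytes_per_elem, batch_size=1):
    # Factor 2: one K tensor and one V tensor per layer
    return 2 * num_layers * num_kv_heads * seq_len * d_k * bytes_per_elem * batch_size

size = kv_cache_bytes(num_layers=32, num_kv_heads=32, seq_len=4096, d_k=128, bytes_per_elem=2)
print(size, size / 2**30)                                          # 2147483648 2.0
print(kv_cache_bytes(32, 32, 4096, 128, 2, batch_size=16) / 2**30)  # 32.0
```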

Optimizations: Multi-Query Attention (MQA) shares one K,V head across all Q heads, shrinking the KV-cache by a factor of num_heads. Grouped-Query Attention (GQA) is a compromise with num_kv_heads K,V heads, each shared by a group of Q heads (e.g., 8 K,V heads for 32 Q heads, as in Mistral 7B). GQA is used in LLaMA 2 70B and Mistral 7B.
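A minimal grouped-query attention sketch, assuming query heads are assigned to K/V heads in contiguous groups (query heads 0-3 share K/V head 0, and so on); MQA is the special case num_kv_heads = 1. The function name and shapes are illustrative.

```python
import math
import torch

def grouped_query_attention(q, k, v):
    """q: (batch, num_q_heads, seq, d_k); k, v: (batch, num_kv_heads, seq, d_k).
    Each group of num_q_heads // num_kv_heads query heads shares one K/V head."""
    group = q.size(1) // k.size(1)
    k = k.repeat_interleave(group, dim=1)   # expand K/V heads to match query heads
    v = v.repeat_interleave(group, dim=1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
batch, seq, d_k = 2, 10, 64
q = torch.randn(batch, 32, seq, d_k)   # 32 query heads
k = torch.randn(batch, 8, seq, d_k)    # only 8 K/V heads need to be cached
v = torch.randn(batch, 8, seq, d_k)
out = grouped_query_attention(q, k, v)
print(out.shape)                       # torch.Size([2, 32, 10, 64])
print(k.numel() / q.numel())           # 0.25: K is 4x smaller than with 32 K/V heads
```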

Q9: What is the feed-forward network in a Transformer block? Why is it important?

A

The FFN in each Transformer block is: FFN(x) = W_2 * activation(W_1 * x + b_1) + b_2. Typically W_1 projects from d_model to 4*d_model, and W_2 projects back to d_model.

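Why it matters: For fixed attention weights, self-attention outputs a weighted average of the value vectors, which is linear in the values; its only nonlinearity is the softmax that produces the weights. The FFN adds a per-position nonlinearity and holds most of each block's parameters (with d_ff = 4*d_model, about 8*d_model^2 weights versus 4*d_model^2 for attention). Without the FFN, stacking attention layers would be limited in expressiveness. Interpretability research (e.g., Geva et al., 2021) suggests that FFN layers act as key-value memories that store factual knowledge learned during training.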

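Modern variants: GLU (Gated Linear Unit) variants such as SwiGLU (used in LLaMA) replace the simple FFN: FFN_SwiGLU(x) = W_2 * (Swish(W_1 * x) ⊙ (W_3 * x)), where ⊙ is element-wise multiplication and biases are omitted. This adds a gating mechanism with a third weight matrix; LLaMA shrinks the hidden size to about (2/3)*4*d_model so the parameter count stays close to that of the standard FFN.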
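A sketch of both variants, assuming ReLU for the original FFN and bias-free linear layers for SwiGLU. The class names are our own; the SwiGLU hidden size is set to about (2/3)*4*d_model, the sizing LLaMA uses, so the two parameter counts nearly match.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """Original-style FFN: W_2 act(W_1 x + b_1) + b_2, here with ReLU."""
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.w2(F.relu(self.w1(x)))

class SwiGLUFeedForward(nn.Module):
    """W_2 (Swish(W_1 x) * (W_3 x)) with elementwise product and no biases."""
    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)  # gate branch
        self.w3 = nn.Linear(d_model, d_hidden, bias=False)  # value branch
        self.w2 = nn.Linear(d_hidden, d_model, bias=False)  # projection back to d_model

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))     # silu = Swish with beta = 1

def num_params(module):
    return sum(p.numel() for p in module.parameters())

d_model = 512
ffn = FeedForward(d_model, 4 * d_model)
swiglu = SwiGLUFeedForward(d_model, int(2 * 4 * d_model / 3))  # hidden size 1365
x = torch.randn(2, 10, d_model)
print(ffn(x).shape, swiglu(x).shape)      # both torch.Size([2, 10, 512])
print(num_params(ffn), num_params(swiglu))  # 2099712 2096640
```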

Q10: Explain the Vision Transformer (ViT). How does it apply Transformers to images?

A

Key idea: Split an image into fixed-size patches (e.g., 16x16 pixels), flatten each patch into a vector, and treat them as "tokens" in a standard Transformer encoder.

Steps:

  1. Split 224x224 image into 14*14 = 196 patches of 16x16 pixels
  2. Flatten each patch: 16*16*3 = 768-dimensional vector
  3. Linear projection to d_model dimensions
  4. Prepend a learnable [CLS] token
  5. Add learnable positional embeddings
  6. Process through standard Transformer encoder layers
  7. Use the [CLS] token's output for classification
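Steps 1-5 can be sketched as a patch-embedding module, assuming ViT-Base sizes (d_model = 768) and zero-initialized [CLS] and positional parameters as placeholders; the class name is hypothetical.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Patchify, flatten, project, prepend [CLS], add learnable positional embeddings."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, d_model=768):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2                  # 14 * 14 = 196
        self.proj = nn.Linear(in_chans * patch_size * patch_size, d_model)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))    # placeholder init
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, d_model))

    def forward(self, images):
        B, C, H, W = images.shape
        p = self.patch_size
        # (B, C, H, W) -> (B, C, H/p, W/p, p, p): non-overlapping p x p tiles
        patches = images.unfold(2, p, p).unfold(3, p, p)
        # -> (B, H/p * W/p, C * p * p): one flattened vector per patch, row-major order
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * p * p)
        tokens = self.proj(patches)                                  # (B, 196, d_model)
        cls = self.cls_token.expand(B, -1, -1)                       # one [CLS] per image
        tokens = torch.cat([cls, tokens], dim=1)                     # (B, 197, d_model)
        return tokens + self.pos_embed

embed = PatchEmbedding()
images = torch.randn(2, 3, 224, 224)
print(embed(images).shape)  # torch.Size([2, 197, 768])
```

Many implementations fuse steps 1-3 into a single `nn.Conv2d(3, d_model, kernel_size=16, stride=16)`, which applies the same linear map to every non-overlapping patch.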

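vs. CNNs: ViT has weaker built-in inductive biases for locality and translation equivariance than CNNs, so it typically needs more training data; the original ViT was pre-trained on large datasets (ImageNet-21k or JFT-300M) to match or beat strong CNNs. With enough data or pre-training, ViT matches or exceeds CNNs. Hybrid approaches (CNN stem + Transformer) can work better on medium-sized datasets.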

Q11: What are scaling laws? How do they guide model development?

A

Kaplan et al. (2020) and Chinchilla (Hoffmann et al., 2022) scaling laws: Model performance (measured as loss) follows power-law relationships with three variables: model size (N parameters), dataset size (D tokens), and compute budget (C FLOPs).

Key finding (Chinchilla): For compute-optimal training, model size and training tokens should scale in roughly equal proportion (about 20 tokens per parameter). The previous approach (GPT-3 era) used very large models with relatively little data. Chinchilla (70B parameters, 1.4T tokens) used about the same compute as Gopher (280B parameters, 300B tokens), a 4x smaller model trained on about 4x more data, and outperformed it.

Practical implications:

  • Loss follows L = a/N^alpha + b/D^beta + c (irreducible loss)
  • You can predict the performance of a 100B model by training several small models and extrapolating the power law
  • If you have a fixed compute budget, there is an optimal split between model size and training data
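  • Data matters as much as parameters: many GPT-3-era models were undertrained relative to their size. Since Chinchilla, models such as LLaMA are often trained on far more tokens than the compute-optimal amount, because a smaller model trained longer is cheaper to serve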
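A back-of-the-envelope sketch of the compute-optimal split, assuming two common rules of thumb rather than the fitted Chinchilla coefficients: training FLOPs C ≈ 6 * N * D, and D ≈ 20 * N tokens per parameter. The helper name is hypothetical.

```python
import math

def compute_optimal(C, tokens_per_param=20.0):
    """Split a FLOP budget C using C ~ 6 * N * D and D ~ tokens_per_param * N."""
    N = math.sqrt(C / (6 * tokens_per_param))
    return N, tokens_per_param * N

C = 6 * 70e9 * 1.4e12                     # roughly Chinchilla's budget, ~5.9e23 FLOPs
N, D = compute_optimal(C)
print(f"{N:.3g} params, {D:.3g} tokens")  # 7e+10 params, 1.4e+12 tokens

# 4x the compute -> 2x the parameters and 2x the tokens
N4, D4 = compute_optimal(4 * C)
print(round(N4 / N, 6), round(D4 / D, 6))  # 2.0 2.0
```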

Q12: What is Flash Attention and why is it faster despite the same O(n^2) complexity?

A

The bottleneck is memory, not compute: Standard attention materializes the full n*n attention matrix in GPU HBM (high bandwidth memory). Reading/writing this large matrix is slower than the actual math. Flash Attention never materializes the full matrix.

How it works: Flash Attention tiles the Q, K, V matrices into small blocks that fit in GPU SRAM (fast on-chip memory). It computes attention block-by-block using online softmax (incrementally computing softmax without seeing all values). Only the final output is written to HBM, not the intermediate attention matrix.
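The online-softmax idea can be sketched in plain PyTorch. This only illustrates the arithmetic over key/value blocks; it does not reproduce the real kernel's SRAM tiling, query blocking, or fused backward pass, so it is not faster than standard attention. The function name and `block_size` are our own choices.

```python
import math
import torch

def tiled_attention(Q, K, V, block_size=64):
    """Exact softmax(QK^T / sqrt(d)) V computed over key/value blocks with an online softmax.
    Q: (n_q, d); K, V: (n_k, d). Only (n_q, block_size) scores exist at any time."""
    n_q, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    m = torch.full((n_q, 1), float('-inf'))   # running max of scores per query row
    l = torch.zeros(n_q, 1)                   # running softmax denominator
    acc = torch.zeros(n_q, V.size(1))         # running unnormalized output
    for start in range(0, K.size(0), block_size):
        Kb, Vb = K[start:start + block_size], V[start:start + block_size]
        s = Q @ Kb.T * scale                                 # scores for this block only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        correction = torch.exp(m - m_new)                    # rescale earlier partial sums
        p = torch.exp(s - m_new)
        l = l * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ Vb
        m = m_new
    return acc / l

torch.manual_seed(0)
Q, K, V = torch.randn(128, 32), torch.randn(300, 32), torch.randn(300, 32)
reference = torch.softmax(Q @ K.T / math.sqrt(32), dim=-1) @ V
print(torch.allclose(tiled_attention(Q, K, V), reference, atol=1e-5))  # True
```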

Benefits: 2-4x faster wall-clock time, extra memory drops from O(n^2) to O(n), and much longer sequences fit without running out of memory. It is now standard in most modern LLM training and inference stacks.

Q13: What is the difference between pre-norm and post-norm Transformer architectures?

A

Post-norm (original Transformer): LayerNorm(x + SelfAttention(x)). Layer normalization is applied after the residual addition. Requires careful learning rate warmup to avoid training instability.

Pre-norm (GPT-2, LLaMA): x + SelfAttention(LayerNorm(x)). Layer normalization is applied before the attention/FFN sublayer. More stable training because the residual path is always a clean addition. This is the default in modern architectures.

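Why pre-norm is more stable: In pre-norm, the residual stream is a direct identity path from input to output: each sublayer's output is simply added to its unchanged input, so gradients reach early layers without passing through any normalization. In post-norm, every residual sum passes through a LayerNorm, which rescales the signal and its gradients at each layer; at initialization this produces large gradients near the output layers (Xiong et al., 2020), which is why post-norm needs warmup.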
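Both orderings as thin wrappers around an arbitrary sublayer (a linear layer stands in for attention or the FFN; the class names are our own). Zeroing the sublayer makes the difference visible: pre-norm then passes x through unchanged, while post-norm still rescales it.

```python
import torch
import torch.nn as nn

class PostNormSublayer(nn.Module):
    """Original Transformer: LayerNorm(x + Sublayer(x))."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer, self.norm = sublayer, nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self.sublayer(x))

class PreNormSublayer(nn.Module):
    """GPT-2 / LLaMA style: x + Sublayer(LayerNorm(x)); the residual path is untouched."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer, self.norm = sublayer, nn.LayerNorm(d_model)

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

d_model = 64
x = torch.randn(2, 10, d_model)
post = PostNormSublayer(d_model, nn.Linear(d_model, d_model))
pre = PreNormSublayer(d_model, nn.Linear(d_model, d_model))
print(post(x).shape, pre(x).shape)  # both torch.Size([2, 10, 64])

# A sublayer that outputs zeros: pre-norm is then the identity, post-norm is not
zero = nn.Linear(d_model, d_model)
nn.init.zeros_(zero.weight)
nn.init.zeros_(zero.bias)
print(torch.equal(PreNormSublayer(d_model, zero)(x), x))      # True
print(torch.allclose(PostNormSublayer(d_model, zero)(x), x))  # False
```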

RMSNorm: A simplified version of LayerNorm used in LLaMA: RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma. It skips the mean subtraction (and the bias), which makes it somewhat cheaper to compute, and it empirically works about as well.
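A minimal RMSNorm sketch matching the formula; the eps value is an assumption, and recent PyTorch releases also provide a built-in `torch.nn.RMSNorm`.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """x / sqrt(mean(x^2) + eps) * gamma: no mean subtraction and no bias, unlike LayerNorm."""
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(d_model))   # learnable per-feature scale

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.gamma

norm = RMSNorm(512)
x = torch.randn(2, 10, 512) * 3 + 1   # shifted and scaled input
y = norm(x)
print(y.pow(2).mean(dim=-1).sqrt()[0, :3])  # each position's RMS is ~1
print(y.mean().item())                      # mean is NOT removed (stays positive here)
```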

Q14: How does cross-attention differ from self-attention? Where is it used?

A

Self-attention: Q, K, V all come from the same sequence. Token at position i computes attention over all positions in the same sequence. Used in both encoders and decoders.

Cross-attention: Q comes from one sequence (e.g., decoder), K and V come from a different sequence (e.g., encoder output). The decoder queries the encoder's representation to find relevant information.

Where used:

  • Machine translation (original Transformer): Decoder cross-attends to encoder output
  • Text-to-image (Stable Diffusion): Image generation model cross-attends to text encoder output
  • Multimodal models (e.g., Flamingo): Text model cross-attends to vision encoder features
  • Some retrieval-augmented models (e.g., RETRO): Language model cross-attends to retrieved documents
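A minimal cross-attention sketch with PyTorch's `nn.MultiheadAttention`, using random tensors as stand-ins for decoder states and encoder output. The only change from self-attention is that `key` and `value` come from a different sequence, which may have a different length.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 256
cross_attn = nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)

decoder_states = torch.randn(2, 7, d_model)    # queries: 7 target tokens
encoder_output = torch.randn(2, 20, d_model)   # keys/values: 20 source tokens

out, weights = cross_attn(query=decoder_states, key=encoder_output, value=encoder_output)
print(out.shape)      # torch.Size([2, 7, 256]): one output per query position
print(weights.shape)  # torch.Size([2, 7, 20]): each target token attends over the source
```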

Q15: Implement a complete Transformer encoder block in PyTorch.

A

Combines multi-head self-attention, feed-forward network, layer normalization, residual connections, and dropout. Uses pre-norm architecture.

import torch
import torch.nn as nn

class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        # Multi-head self-attention
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        # Layer norms (pre-norm architecture)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Pre-norm self-attention with residual
        normed = self.norm1(x)
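        # attn_mask uses nn.MultiheadAttention's convention: shape (seq_len, seq_len) or
        # (batch * num_heads, seq_len, seq_len); in a boolean mask, True = NOT allowed to attend
        attn_out, _ = self.self_attn(normed, normed, normed, attn_mask=mask)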
        x = x + self.dropout(attn_out)

        # Pre-norm FFN with residual
        normed = self.norm2(x)
        x = x + self.ffn(normed)

        return x

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model=512, num_heads=8, num_layers=6,
                 d_ff=2048, max_len=512, dropout=0.1):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def forward(self, input_ids, mask=None):
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device)

        x = self.token_embed(input_ids) * (self.d_model ** 0.5)  # Scale embeddings
        x = x + self.pos_embed(positions)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, mask)

        return self.final_norm(x)

# Test
encoder = TransformerEncoder(vocab_size=30000)
input_ids = torch.randint(0, 30000, (2, 50))
output = encoder(input_ids)
print(f"Output: {output.shape}")  # [2, 50, 512]
print(f"Params: {sum(p.numel() for p in encoder.parameters()):,}")  # 34,537,472 (~34.5M)

Key Takeaways

  • Scaling by sqrt(d_k) prevents softmax saturation and maintains healthy gradients
  • Multi-head attention lets the model attend to different patterns simultaneously at roughly the same cost as one head of the same total width
  • BERT = encoder (bidirectional, understanding), GPT = decoder (causal, generation)
  • Self-attention is O(n^2) in time and memory; Flash Attention removes the O(n^2) memory cost through tiling, though compute stays O(n^2)
  • KV-cache is essential for efficient generation; GQA/MQA reduce its memory footprint
  • Chinchilla scaling laws: scale model size and training tokens in equal proportion (roughly 20 tokens per parameter) for compute-optimal training
  • Be able to implement multi-head attention from scratch — it is the most common DL coding question