Self-Attention Mechanism
A comprehensive guide to the self-attention mechanism, part of a deep dive into the transformer architecture.
Understanding Self-Attention
Self-attention is the mechanism that allows each token in a sequence to compute a weighted representation of all other tokens. Unlike cross-attention (where queries come from one sequence and keys/values from another), in self-attention, the queries, keys, and values all come from the same sequence. This enables the model to learn relationships between different positions within a single input.
Consider the sentence "The cat sat on the mat because it was tired." For a human, it is obvious that "it" refers to "the cat." Self-attention enables the model to learn this kind of relationship by computing how strongly each token should attend to every other token.
Queries, Keys, and Values
The self-attention mechanism is built on three learned linear projections of the input:
- Query (Q) — Represents "what am I looking for?" Each token generates a query vector that describes what information it needs from other tokens.
- Key (K) — Represents "what do I contain?" Each token generates a key vector that describes what information it offers to other tokens.
- Value (V) — Represents "what information do I provide?" When a query matches a key, the corresponding value is what gets passed along.
Think of it like a library search. The query is your search term, the keys are the book titles/descriptions, and the values are the actual book contents. The attention score between a query and a key determines how much of that book's content you read.
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Learned linear projections
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** 0.5

    def forward(self, x, mask=None):
        # x shape: (batch, seq_len, embed_dim)
        Q = self.W_q(x)  # (batch, seq_len, embed_dim)
        K = self.W_k(x)  # (batch, seq_len, embed_dim)
        V = self.W_v(x)  # (batch, seq_len, embed_dim)

        # Compute scaled dot-product attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # scores shape: (batch, seq_len, seq_len)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(weights, V)
        return output, weights
The Attention Matrix
The attention matrix (after softmax) is a square matrix of size (seq_len x seq_len) where each row sums to 1. Entry (i, j) represents how much token i attends to token j. This matrix is incredibly informative for understanding what the model has learned.
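A quick sanity check of this row-sum property, using a random score matrix (the values are illustrative only; any real scores behave the same way after softmax):

```python
import torch

torch.manual_seed(0)
seq_len = 5

# Random attention scores for a single head (illustrative values)
scores = torch.randn(seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)

# Softmax over the last dimension guarantees each row sums to 1
assert torch.allclose(weights.sum(dim=-1), torch.ones(seq_len))
```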
Attention Patterns
Research has identified common patterns in learned attention matrices:
- Diagonal pattern — Tokens attend primarily to themselves (local self-reference)
- Vertical stripes — All tokens attend to a specific important token (like the CLS token or punctuation)
- Block diagonal — Tokens attend to nearby tokens (local context)
- Coreference pattern — Pronouns attend strongly to their referent nouns
- Syntactic pattern — Tokens attend to syntactically related tokens (subject-verb, modifier-noun)
Computational Complexity
Self-attention computes pairwise interactions between all tokens, resulting in O(n^2 * d) time and O(n^2) space complexity, where n is the sequence length and d the embedding dimension:
- Memory — The attention matrix requires storing n^2 values per head per layer
- Computation — Two matrix multiplications: QK^T of shape (n x d)(d x n), and the weighted sum of shape (n x n)(n x d), each costing O(n^2 * d)
- Practical limits — Without efficiency techniques, the quadratic cost makes standard transformers expensive beyond roughly 2048-8192 tokens
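To make the memory term concrete, here is a back-of-the-envelope sketch. The 12-layer, 12-head configuration is a hypothetical GPT-2-sized example, not a claim about any specific model:

```python
def attention_matrix_bytes(seq_len, num_heads, num_layers, bytes_per_value=2):
    """Memory to store all attention matrices at once (fp16 by default)."""
    return seq_len ** 2 * num_heads * num_layers * bytes_per_value

# Hypothetical 12-layer, 12-head model: an 8x longer sequence
# costs 64x more attention-matrix memory (quadratic growth)
for n in (1024, 8192):
    gib = attention_matrix_bytes(n, 12, 12) / 2**30
    print(f"n={n}: {gib:.2f} GiB")
```

Running this shows the jump from under a gibibyte at n=1024 to roughly 18 GiB at n=8192, which is why long-context models need the alternatives below.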
Efficient Attention Alternatives
Several approaches reduce the quadratic complexity:
- Sparse attention (BigBird, Longformer) — Only compute attention for a subset of token pairs
- Linear attention (Performer, Linear Transformer) — Approximate softmax attention with linear-time kernel methods
- Flash Attention — Hardware-aware algorithm that reduces memory I/O, not asymptotic complexity
- Sliding window (Mistral) — Each token attends only to a fixed window of nearby tokens
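As an illustration of the sliding-window idea, here is a minimal sketch of a causal sliding-window mask (the function name and window size are assumptions for this example). Each row allows at most `window` positions, so attention cost drops from O(n^2) to O(n * w):

```python
import torch

def sliding_window_mask(seq_len, window):
    """Causal sliding-window mask: token i may attend to tokens
    in [i - window + 1, i]. True marks allowed positions."""
    i = torch.arange(seq_len).unsqueeze(1)  # row index (query position)
    j = torch.arange(seq_len).unsqueeze(0)  # column index (key position)
    return (j <= i) & (j > i - window)

mask = sliding_window_mask(6, window=3)
```

This boolean mask can be passed to an attention implementation that blocks positions where the mask is False (for example, via `masked_fill` on the scores, as in the SelfAttention module above).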
# Flash Attention usage with PyTorch 2.0+
import torch.nn.functional as F

# Fused scaled dot-product attention; PyTorch dispatches to
# Flash Attention when the inputs and hardware support it
attn_output = F.scaled_dot_product_attention(
    query, key, value,
    is_causal=True  # causal masking for decoder-style autoregressive models
)
# Note: pass either attn_mask or is_causal=True, not both —
# PyTorch raises an error if both are set
Masked Self-Attention
In autoregressive models (like GPT), each token should only attend to previous tokens, not future ones. This is enforced with a causal mask: every entry strictly above the diagonal of the score matrix is set to negative infinity before the softmax, so the attention weights on future positions become exactly zero and no information flows backward from them.
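As a sketch, a causal mask in the boolean form expected by the SelfAttention module above (where mask == 0 marks blocked positions) can be built with torch.tril, and converted to the additive negative-infinity form when needed:

```python
import torch

seq_len = 4

# Lower-triangular boolean mask: 1 where attention is allowed
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

# Equivalent additive form: -inf strictly above the diagonal, 0 elsewhere,
# added to the raw scores before the softmax
additive_mask = torch.zeros(seq_len, seq_len).masked_fill(
    causal_mask == 0, float('-inf')
)
```

Either form works: the boolean mask is applied with `masked_fill` on the scores, while the additive mask is simply added to them.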
In the next lesson, we explore multi-head attention, which allows the model to jointly attend to information from different representation subspaces at different positions.