Multi-Head Attention
How transformers learn what to focus on
Self-attention is the key innovation of transformers. It allows each position in a sequence to "look at" (attend to) every other position.
The Core Idea
When processing a word, the model should consider the context:
"The cat sat on the ___"
↑
What goes here? "mat", "chair", "floor"

The model needs to look back at "cat" to predict "mat".
Visual: How Attention Works
Position 0: "The" ─────────────────────┐
Position 1: "cat" ────► attention ─────┼──► "sat"
Position 2: "sat" ─────────────────────┘
Each position looks at ALL previous positions
to understand context!

The Attention Formula
The attention mechanism is:
Attention(Q, K, V) = softmax(Q × K^T / √d) × V

Where:
- Q (Query): What I'm looking for
- K (Key): What I offer for matching
- V (Value): What I can share
- d: the dimension of each query/key vector (head_dim in the code below)
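The formula can be sketched in plain Python for a single query vector. The `softmax` and `attention` helpers here are illustrative, not microgpt's actual functions:

```python
import math

def softmax(xs):
    # Shift by the max for numerical stability; the result is unchanged.
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def attention(q, ks, vs):
    # Score each stored key against the query: dot(q, k) / sqrt(d)
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in ks]
    weights = softmax(scores)
    # Output is the attention-weighted sum of the value vectors
    return [sum(w * v[j] for w, v in zip(weights, vs)) for j in range(len(vs[0]))]

# The query matches the first key most strongly, so the output
# is dominated by the first value vector.
out = attention([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[10.0, 0.0], [0.0, 10.0]])
```

Because the weights sum to 1, the output is always a convex combination of the value vectors.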
Visual: Q, K, V Analogy
Imagine a key-value store:
Key (K) Value (V)
┌─────┐ ┌─────┐
│ "a" │──────►│ 1.0 │
│ "b" │──────►│ 2.0 │
│ "c" │──────►│ 3.0 │
└─────┘ └─────┘
↑
Query: "a"
(looking for "a" - gets value 1.0)

Breaking Down Attention
Step 1: Compute Q, K, V
q = linear(x, state_dict[f'layer{li}.attn_wq']) # Query
k = linear(x, state_dict[f'layer{li}.attn_wk']) # Key
v = linear(x, state_dict[f'layer{li}.attn_wv'])  # Value

Each input position gets its own Q, K, and V vectors.
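microgpt's `linear` helper isn't shown in this section; a minimal sketch consistent with how it's used above is a plain matrix-vector product (the row-major weight layout is an assumption):

```python
def linear(x, w):
    # w is a list of output rows; each output element is dot(row, x)
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

x = [1.0, 2.0]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # projects 2 dims -> 3 dims
y = linear(x, w)  # [1.0, 2.0, 3.0]
```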
Step 2: Store Keys and Values
keys[li].append(k)
values[li].append(v)

We store all previous keys and values so we can attend to them on every later step.
Step 3: Compute Attention Scores
attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]

This is scaled dot-product attention:
- Take dot product of Q and each K
- Divide by √head_dim (scaling)
- This gives a score for each position
Step 4: Softmax
attn_weights = softmax(attn_logits)

Convert the scores to probabilities that sum to 1.
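A standard, numerically stable implementation (microgpt's exact version may differ) subtracts the max logit before exponentiating, which avoids overflow without changing the result:

```python
import math

def softmax(logits):
    # softmax is invariant to adding a constant to every logit,
    # so subtracting the max avoids exp() overflow for free.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

w = softmax([2.0, 1.0, 0.1])     # probabilities; the largest logit wins
safe = softmax([1000.0, 999.0])  # no OverflowError thanks to the shift
```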
Step 5: Weighted Sum
head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]

Multiply attention weights by values and sum.
Why Multiple Heads?
Having multiple attention heads lets the model attend to different types of information:
| Head | What it learns |
|---|---|
| Head 1 | Grammar/syntax |
| Head 2 | Word relationships |
| Head 3 | Position/ordering |
| Head 4 | Named entities |
Each head learns a different "type" of attention pattern.
The Code in microgpt
for h in range(n_head):
    # Get this head's slice
    hs = h * head_dim
    q_h = q[hs:hs+head_dim]
    k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
    v_h = [vi[hs:hs+head_dim] for vi in values[li]]
    # Attention for this head
    attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]
    attn_weights = softmax(attn_logits)
    head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]
    x_attn.extend(head_out)

Each head processes its own slice (head_dim elements) of the Q, K, V vectors.
Visual Example
Position 1 ("cat"):
Q = "What should I look for?"
Attend to: The(0.3), cat(0.7)

Position 4 ("the"):
Q = "What should I look for?"
Attend to: The(0.1), cat(0.4), sat(0.3), on(0.1), the(0.1)

The model learns which positions are relevant!

Why Scaled Attention?
The division by √head_dim is important:
- If Q and K have large magnitudes, the dot product gets very large
- This pushes softmax into regions with tiny gradients
- Dividing by √d keeps everything in a good range
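A quick numeric check of the point above (example numbers, head_dim = 64): without the √d scaling, softmax over even moderately large dot products is nearly one-hot, which starves most positions of gradient.

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

d = 64
raw = [8.0, 6.0, 4.0]                      # unscaled dot products
scaled = [x / math.sqrt(d) for x in raw]   # divided by sqrt(head_dim)

print(softmax(raw))     # ~[0.87, 0.12, 0.02] -- nearly one-hot
print(softmax(scaled))  # ~[0.42, 0.33, 0.25] -- much smoother
```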
Causal Attention
In GPT, we only attend to previous positions (causal/masked attention):
Position 0: can only attend to position 0
Position 1: can attend to positions 0, 1
Position 2: can attend to positions 0, 1, 2
...

This is why GPT generates left-to-right: each token can only see what came before it.
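Batched implementations enforce this with a causal mask added to the attention logits before the softmax; in microgpt's token-at-a-time loop, causality falls out automatically because only past keys are in the cache. A sketch of the standard lower-triangular mask:

```python
import math

T = 4  # sequence length
# mask[i][j] = 0 if position i may attend to position j, else -inf
mask = [[0.0 if j <= i else -math.inf for j in range(T)] for i in range(T)]

# Adding the mask to the logits before softmax zeroes out future
# positions: exp(-inf) == 0, so they get zero attention weight.
for row in mask:
    print(row)
```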
Summary
Self-attention allows each position to:
- Query: Ask "what should I look for?"
- Attend: Score how relevant each previous position is
- Collect: Gather information weighted by relevance
Multi-head attention runs multiple attention operations in parallel, letting the model learn different types of relationships.
This is the heart of the transformer architecture!