Multi-Head Attention

How transformers learn what to focus on

Self-attention is the key innovation of transformers. It allows each position in a sequence to "look at" (attend to) every other position.

The Core Idea

When processing a word, the model should consider the context:

"The cat sat on the ___"

              What goes here? "mat", "chair", "floor"

The model needs to look back at "cat" to predict "mat".

Visual: How Attention Works

Position 0: "The"    ─────────────────────┐
Position 1: "cat"   ────► attention ─────┼──► "sat"
Position 2: "sat"   ─────────────────────┘

Each position looks at ALL previous positions
to understand context!

The Attention Formula

The attention mechanism is:

Attention(Q, K, V) = softmax(Q × K^T / √d) × V

Where:

  • Q (Query): what this position is looking for
  • K (Key): what each position advertises about itself, for matching against queries
  • V (Value): the information each position actually passes along once matched
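
The formula above can be sketched directly with NumPy (a minimal illustration of the equation; microgpt itself uses plain Python lists, as shown in the code later on this page):

```python
import numpy as np

def softmax(x):
    # Subtract the row max before exponentiating for numerical stability
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    # Q, K, V: (seq, d) matrices, one row per position
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)   # (seq, seq) similarity scores
    weights = softmax(scores)       # each row sums to 1
    return weights @ V              # weighted sum of value rows

Q = np.random.randn(3, 4)
K = np.random.randn(3, 4)
V = np.random.randn(3, 4)
out = attention(Q, K, V)
print(out.shape)  # (3, 4)
```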

Visual: Q, K, V Analogy

Imagine a key-value store:

  Key (K)        Value (V)
  ┌─────┐       ┌─────┐
  │ "a" │──────►│ 1.0 │
  │ "b" │──────►│ 2.0 │
  │ "c" │──────►│ 3.0 │
  └─────┘       └─────┘

        Query: "a"
        (looking for "a" - gets value 1.0)
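
The difference from an ordinary key-value store is that the lookup is "soft": the query partially matches every key, and the result blends all the values. A toy numeric sketch (the keys, values, and similarity score here are made up for illustration):

```python
import math

# Toy store: three scalar keys and values
keys   = [1.0, 2.0, 3.0]
values = [10.0, 20.0, 30.0]
query  = 2.0  # "looking for" something close to the second key

# Score each key by similarity to the query (negative squared distance)
scores = [-(query - k) ** 2 for k in keys]

# Softmax turns scores into weights that sum to 1
exps = [math.exp(s) for s in scores]
weights = [e / sum(exps) for e in exps]

# The result is a blend of ALL values, dominated by the best match
result = sum(w * v for w, v in zip(weights, values))
print(round(result, 2))  # 20.0 (the matching value dominates)
```

This softness is what makes attention differentiable, and therefore learnable by gradient descent.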

Breaking Down Attention

Step 1: Compute Q, K, V

q = linear(x, state_dict[f'layer{li}.attn_wq'])  # Query
k = linear(x, state_dict[f'layer{li}.attn_wk'])  # Key
v = linear(x, state_dict[f'layer{li}.attn_wv'])  # Value

Each input position gets its own Q, K, and V vectors.

Step 2: Store Keys and Values

keys[li].append(k)
values[li].append(v)

We store all previous keys and values so we can attend to them.

Step 3: Compute Attention Scores

attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]

This is scaled dot-product attention:

  1. Take dot product of Q and each K
  2. Divide by √head_dim (scaling)
  3. This gives a score for each position

Step 4: Softmax

attn_weights = softmax(attn_logits)

Convert scores to probabilities (sum to 1).
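
The softmax used here can be written in a few lines. A sketch of a numerically stable version (subtracting the max is a standard trick; microgpt's own definition may differ):

```python
import math

def softmax(logits):
    # Subtract the max so exp() can't overflow; the weights are unchanged
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

weights = softmax([1.0, 2.0, 3.0])
print(sum(weights))  # ~1.0
```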

Step 5: Weighted Sum

head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]

Multiply attention weights by values and sum.
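
Steps 3 to 5 together form one head's attention for a single query position. A self-contained sketch in the same list-based style as microgpt (the function name is my own, not from the codebase):

```python
import math

def head_attention(q_h, k_h, v_h):
    # q_h: query vector for the current position (length head_dim)
    # k_h, v_h: lists of key/value vectors for all positions so far
    head_dim = len(q_h)
    # Step 3: scaled dot-product score against every stored key
    attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim ** 0.5
                   for t in range(len(k_h))]
    # Step 4: softmax -> weights that sum to 1
    m = max(attn_logits)
    exps = [math.exp(l - m) for l in attn_logits]
    attn_weights = [e / sum(exps) for e in exps]
    # Step 5: weighted sum of the stored values
    return [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h)))
            for j in range(head_dim)]

out = head_attention([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]])
print(len(out))  # 2 (same length as head_dim)
```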

Why Multiple Heads?

Having multiple attention heads lets the model attend to different types of information:

  Head      What it learns (example)
  ────────  ────────────────────────
  Head 1    Grammar/syntax
  Head 2    Word relationships
  Head 3    Position/ordering
  Head 4    Named entities

Each head learns a different "type" of attention pattern.

The Code in microgpt

for h in range(n_head):
    # Get this head's slice
    hs = h * head_dim
    q_h = q[hs:hs+head_dim]
    k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
    v_h = [vi[hs:hs+head_dim] for vi in values[li]]

    # Attention for this head
    attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / head_dim**0.5 for t in range(len(k_h))]
    attn_weights = softmax(attn_logits)
    head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]

    x_attn.extend(head_out)

Each head processes a slice (head_dim) of the Q, K, V vectors.
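
The slicing is easy to check with a toy example. Assuming n_head = 2 and head_dim = 3, the slices partition the full vector with nothing shared and nothing left over:

```python
n_head, head_dim = 2, 3
q = list(range(n_head * head_dim))  # stand-in query vector [0..5]

slices = []
for h in range(n_head):
    hs = h * head_dim
    slices.append(q[hs:hs + head_dim])  # each head sees only its own chunk

print(slices)  # [[0, 1, 2], [3, 4, 5]]
```

Because the heads work on disjoint slices and their outputs are concatenated back with `extend`, the output vector has the same length as the input.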

Visual Example

For "The cat sat on the", each position attends only to itself and earlier positions (the weights below are illustrative):

Position 1 ("cat"):
  Q = "What should I look for?"
  Attend to: The(0.3), cat(0.7)

Position 4 ("the"):
  Q = "What should I look for?"
  Attend to: The(0.1), cat(0.4), sat(0.3), on(0.1), the(0.1)

The model learns which positions are relevant!

Why Scaled Attention?

The division by √head_dim is important:

  • If Q and K have large magnitudes, the dot product gets very large
  • This pushes softmax into regions with tiny gradients
  • Dividing by √d keeps everything in a good range
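
The effect is easy to see empirically: the dot product of two random d-dimensional vectors has standard deviation about √d, so dividing by √d keeps the scores near unit scale regardless of d. A quick NumPy check:

```python
import numpy as np

rng = np.random.default_rng(0)
for d in [16, 256]:
    q = rng.standard_normal((10000, d))
    k = rng.standard_normal((10000, d))
    raw = (q * k).sum(axis=1)    # unscaled dot products: std grows like sqrt(d)
    scaled = raw / np.sqrt(d)    # with the 1/sqrt(d) factor: std stays near 1
    print(d, raw.std().round(1), scaled.std().round(2))
```

Without the scaling, larger heads would feed ever-larger logits into softmax, saturating it and starving training of gradient signal.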

Causal Attention

In GPT, we only attend to previous positions (causal/masked attention):

Position 0: can only attend to position 0
Position 1: can attend to positions 0, 1
Position 2: can attend to positions 0, 1, 2
...

This is why GPT generates left to right: each token can only see what came before it.
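
microgpt gets causality for free by only ever appending past keys and values to its cache. In the batched matrix formulation, the same constraint is usually enforced by masking future positions to -inf before the softmax, so they receive zero weight. A NumPy sketch:

```python
import numpy as np

seq = 4
scores = np.zeros((seq, seq))  # stand-in for Q @ K.T / sqrt(d)

# Lower-triangular mask: position i may only attend to positions <= i
mask = np.tril(np.ones((seq, seq), dtype=bool))
scores = np.where(mask, scores, -np.inf)

# Softmax row by row; exp(-inf) = 0, so future positions get zero weight
e = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)
print(weights)  # row i is uniform over positions 0..i, zero afterwards
```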

Summary

Self-attention allows each position to:

  1. Query: Ask "what should I look for?"
  2. Attend: Score how relevant each previous position is
  3. Collect: Gather information weighted by relevance

Multi-head attention runs multiple attention operations in parallel, letting the model learn different types of relationships.

This is the heart of the transformer architecture!
