The GPT Forward Pass
Putting it all together - how data flows through the model
Now that we understand all the components, let's see how they work together in the gpt() function: the forward pass.
The Big Picture
The forward pass transforms:
- A single token (at a specific position)
- Into probability predictions for the next token
Input: token_id + pos_id → [GPT Model] → Output: logits for next token

The gpt() Function
Here's the complete forward pass in microgpt:
    def gpt(token_id, pos_id, keys, values):
        # 1. Get embeddings
        tok_emb = state_dict['wte'][token_id]          # Token embedding
        pos_emb = state_dict['wpe'][pos_id]            # Position embedding
        x = [t + p for t, p in zip(tok_emb, pos_emb)]  # Combine

        # 2. Normalize
        x = rmsnorm(x)

        # 3. Process through each layer
        for li in range(n_layer):
            x = transformer_block(x, li, keys, values)

        # 4. Project to vocabulary
        logits = linear(x, state_dict['lm_head'])
        return logits

Let's trace through each step!
Step 1: Get Embeddings
    tok_emb = state_dict['wte'][token_id]
    pos_emb = state_dict['wpe'][pos_id]
    x = [t + p for t, p in zip(tok_emb, pos_emb)]

- Look up the token embedding for the current character
- Look up the position embedding for the current position
- Add them together
This gives us a vector representing "the character at this position."
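Concretely, step 1 is just two table lookups and an elementwise add. Here is a toy sketch with made-up 2-dimensional embeddings (the real tables live in state_dict and are learned, and much larger):

```python
# Toy embedding tables (made-up numbers; the real model's are learned).
state_dict = {
    'wte': [[0.10, 0.20], [0.30, 0.40]],  # one row per token in the vocabulary
    'wpe': [[0.01, 0.02], [0.03, 0.04]],  # one row per position
}

token_id, pos_id = 1, 0
tok_emb = state_dict['wte'][token_id]  # row 1 of the token table
pos_emb = state_dict['wpe'][pos_id]    # row 0 of the position table
x = [t + p for t, p in zip(tok_emb, pos_emb)]  # elementwise sum of the two vectors
```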
Step 2: Normalize
    x = rmsnorm(x)

Apply RMSNorm to keep the values at a good scale.
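rmsnorm itself is only a few lines. Here is a sketch of the idea; the eps value, and whether a learned gain is applied afterward, are assumptions here rather than microgpt's exact code:

```python
import math

def rmsnorm(x, eps=1e-5):
    # Divide each element by the root-mean-square of the vector,
    # so the output has RMS ~1 regardless of the input's scale.
    rms = math.sqrt(sum(xi * xi for xi in x) / len(x) + eps)
    return [xi / rms for xi in x]

out = rmsnorm([3.0, 4.0])  # input RMS is about 3.54; output RMS is ~1.0
```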
Step 3: Process Layers
    for li in range(n_layer):
        x = transformer_block(x, li, keys, values)

Run through each transformer layer. A layer contains:
- Multi-head attention
- MLP (feed-forward network)
We'll look at this in detail next.
Step 4: Project to Vocabulary
    logits = linear(x, state_dict['lm_head'])

Take the final hidden state and project it to the vocabulary size. Each number in the output is a score for one possible next character.
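The logits are unnormalized scores. During generation they get turned into probabilities with a softmax; here is the standard recipe (an illustration, not necessarily microgpt's exact code):

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability;
    # the result is a probability distribution over the vocabulary.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1])  # highest logit -> highest probability
```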
The Transformer Block
Here's what happens in each layer:
    # 1) Multi-head attention block
    x_residual = x
    x = rmsnorm(x)
    q = linear(x, state_dict[f'layer{li}.attn_wq'])
    k = linear(x, state_dict[f'layer{li}.attn_wk'])
    v = linear(x, state_dict[f'layer{li}.attn_wv'])
    keys[li].append(k)
    values[li].append(v)
    # Compute attention for each head...
    x = linear(x_attn, state_dict[f'layer{li}.attn_wo'])
    x = [a + b for a, b in zip(x, x_residual)]  # Residual

    # 2) MLP block
    x_residual = x
    x = rmsnorm(x)
    x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
    x = [xi.relu() ** 2 for xi in x]
    x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
    x = [a + b for a, b in zip(x, x_residual)]  # Residual

Visual Summary
Input: character 'e' at position 0
│
▼
┌─────────────────────────────────────────────┐
│ 1. Add Token + Position Embeddings │
│ "e" + "position 0" = vector │
│ │
│ 2. RMSNorm │
│ │
│ 3. For each layer: │
│ ├─ Multi-Head Attention │
│ │ └─ Each head attends to prev │
│ ├─ Residual connection │
│ ├─ MLP │
│ └─ Residual connection │
│ │
│ 4. Project to vocabulary │
└─────────────────────────────────────────────┘
│
▼
Output: logits [score for 'a', score for 'b', ...]

What's Stored in Keys/Values?
The keys and values arrays store the K and V vectors from all previous positions:
    keys[li].append(k)    # Store key for this position
    values[li].append(v)  # Store value for this position

This allows attention to look at all previous positions when computing the current one.
Why Residuals?
The lines:
    x = [a + b for a, b in zip(x, x_residual)]

Add the input back to the output. This is called a residual connection (or skip connection). It helps:
- Gradients flow through the network during training
- The network learn identity mappings when helpful
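The second point is easy to see with a toy sublayer: if a sublayer has nothing useful to contribute, it can learn to output (near) zero, and the residual connection then makes the whole block an identity mapping:

```python
def useless_sublayer(x):
    # Imagine a trained sublayer that has learned its output should be zero here.
    return [0.0 for _ in x]

x = [1.0, -2.0, 3.0]
out = [a + b for a, b in zip(useless_sublayer(x), x)]  # identical to x
```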
Summary
The GPT forward pass:
- Embed the token and position
- Normalize the embeddings
- Process through each transformer layer:
  - Multi-head attention (with residual)
  - MLP (with residual)
- Project to vocabulary for next-token predictions
That's how a single token flows through the network to produce predictions!
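To generate text, this forward pass is simply called in a loop: feed a token, get logits, pick the next token, repeat. The sketch below uses a dummy stand-in for gpt() (a fixed lookup table) and greedy argmax selection, just to show the loop structure; microgpt's actual model and sampling differ:

```python
vocab = ['h', 'e', 'l', 'o']
next_of = {0: 1, 1: 2, 2: 3, 3: 3}  # hypothetical "model knowledge"

def dummy_gpt(token_id, pos_id, keys, values):
    # Stand-in for the real gpt(): returns one logit per vocab entry.
    return [1.0 if i == next_of[token_id] else 0.0 for i in range(len(vocab))]

n_layer = 2
keys = [[] for _ in range(n_layer)]    # per-layer K cache
values = [[] for _ in range(n_layer)]  # per-layer V cache

token_id, out = 0, ['h']
for pos_id in range(3):
    logits = dummy_gpt(token_id, pos_id, keys, values)
    token_id = max(range(len(logits)), key=lambda i: logits[i])  # greedy pick
    out.append(vocab[token_id])

print(''.join(out))  # prints "helo"
```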