
The GPT Forward Pass

Putting it all together - how data flows through the model

Now we understand all the components. Let's see how they work together in the gpt() function - the forward pass.

The Big Picture

The forward pass transforms:

  • A single token (at a specific position)
  • Into probability predictions for the next token

Input: token_id + pos_id → [GPT Model] → Output: logits for next token

The gpt() Function

Here's the complete forward pass in microgpt:

def gpt(token_id, pos_id, keys, values):
    # 1. Get embeddings
    tok_emb = state_dict['wte'][token_id]  # Token embedding
    pos_emb = state_dict['wpe'][pos_id]   # Position embedding
    x = [t + p for t, p in zip(tok_emb, pos_emb)]  # Combine

    # 2. Normalize
    x = rmsnorm(x)

    # 3. Process through each layer
    for li in range(n_layer):
        x = transformer_block(x, li, keys, values)

    # 4. Project to vocabulary
    logits = linear(x, state_dict['lm_head'])
    return logits

Let's trace through each step!

Step 1: Get Embeddings

tok_emb = state_dict['wte'][token_id]
pos_emb = state_dict['wpe'][pos_id]
x = [t + p for t, p in zip(tok_emb, pos_emb)]

  1. Look up the token embedding for the current character
  2. Look up the position embedding for the current position
  3. Add them together

This gives us a vector representing "the character at this position."
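As a toy sketch, the combination looks like this (the 3-dimensional embedding tables below are made up for illustration; the real ones live in state_dict['wte'] and state_dict['wpe']):

```python
# Hypothetical toy embedding tables -- stand-ins for the real
# state_dict['wte'] (tokens) and state_dict['wpe'] (positions).
wte = {0: [0.1, 0.2, 0.3], 1: [0.5, 0.0, -0.1]}
wpe = {0: [0.01, 0.01, 0.01], 1: [0.02, 0.02, 0.02]}

token_id, pos_id = 1, 0
tok_emb = wte[token_id]  # look up the token's vector
pos_emb = wpe[pos_id]    # look up the position's vector
x = [t + p for t, p in zip(tok_emb, pos_emb)]  # elementwise sum
# x is approximately [0.51, 0.01, -0.09]
```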

Step 2: Normalize

x = rmsnorm(x)

Apply RMSNorm to rescale the vector so its magnitude stays stable before it enters the transformer layers.
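To see the mechanics, here is a plain-Python sketch of what an RMSNorm does (the real rmsnorm in microgpt operates on autograd values, but the arithmetic is the same idea):

```python
import math

def rmsnorm_sketch(x, eps=1e-5):
    # Divide by the root-mean-square of the elements, so the
    # output vector has an RMS of roughly 1.
    ms = sum(xi * xi for xi in x) / len(x)  # mean of squares
    return [xi / math.sqrt(ms + eps) for xi in x]

y = rmsnorm_sketch([3.0, 4.0])
# the mean of squares of y is now ~1.0
```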

Step 3: Process Layers

for li in range(n_layer):
    x = transformer_block(x, li, keys, values)

Run through each transformer layer. A layer contains:

  1. Multi-head attention
  2. MLP (feed-forward network)

We'll look at this in detail next.

Step 4: Project to Vocabulary

logits = linear(x, state_dict['lm_head'])

Take the final hidden state and project it to the vocabulary size. Each number in the output represents the "score" for each possible next character.
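The logits themselves are raw scores; turning them into probabilities is a softmax away. A minimal sketch (the actual sampling code may differ in details such as temperature):

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability,
    # then normalize so the outputs sum to 1.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1])  # higher logit -> higher probability
```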

The Transformer Block

Here's what happens in each layer:

# 1) Multi-head attention block
x_residual = x
x = rmsnorm(x)
q = linear(x, state_dict[f'layer{li}.attn_wq'])
k = linear(x, state_dict[f'layer{li}.attn_wk'])
v = linear(x, state_dict[f'layer{li}.attn_wv'])
keys[li].append(k)
values[li].append(v)

# Compute attention for each head, producing x_attn...
x = linear(x_attn, state_dict[f'layer{li}.attn_wo'])
x = [a + b for a, b in zip(x, x_residual)]  # Residual

# 2) MLP block
x_residual = x
x = rmsnorm(x)
x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
x = [xi.relu() ** 2 for xi in x]
x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
x = [a + b for a, b in zip(x, x_residual)]  # Residual
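The elided per-head computation is scaled dot-product attention over the cached keys and values. The attend function below is a hypothetical single-head sketch, not microgpt's actual code:

```python
import math

def attend(q, keys, values):
    # Score the query against every cached key (scaled dot product),
    # softmax the scores into weights, then blend the cached values.
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d)
              for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))]

# A query aligned with the first key pulls out mostly the first value.
out = attend([10.0, 0.0], [[1.0, 0.0], [0.0, 1.0]],
             [[1.0, 0.0], [0.0, 1.0]])
```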

Visual Summary

Input: character 'e' at position 0


┌─────────────────────────────────────────────┐
│ 1. Add Token + Position Embeddings          │
│    "e" + "position 0" = vector              │
│                                             │
│ 2. RMSNorm                                  │
│                                             │
│ 3. For each layer:                          │
│    ├─ Multi-Head Attention                  │
│    │   └─ Each head attends to prev         │
│    ├─ Residual connection                   │
│    ├─ MLP                                   │
│    └─ Residual connection                   │
│                                             │
│ 4. Project to vocabulary                    │
└─────────────────────────────────────────────┘


Output: logits [score for 'a', score for 'b', ...]

What's Stored in Keys/Values?

The keys and values lists store the K and V vectors (one per layer) from all previous positions:

keys[li].append(k)   # Store key for this position
values[li].append(v)  # Store value for this position

This allows attention to look at all previous positions when computing the current one.
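The bookkeeping is just list appends: each new position adds one (k, v) pair per layer, so the cache grows with the sequence. A sketch with stand-in vectors:

```python
n_layer = 2
keys = [[] for _ in range(n_layer)]
values = [[] for _ in range(n_layer)]

for pos in range(3):          # process three positions
    for li in range(n_layer):
        k = [0.0, 0.0]        # stand-in for linear(x, attn_wk)
        v = [0.0, 0.0]        # stand-in for linear(x, attn_wv)
        keys[li].append(k)
        values[li].append(v)

# each layer's cache now holds one entry per position seen so far
```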

Why Residuals?

The lines:

x = [a + b for a, b in zip(x, x_residual)]

add the block's input back to its output. This is called a residual connection (or skip connection). It helps:

  • Gradients flow through the network during training
  • The network learn identity mappings when helpful
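A quick sketch of the second point: if a block's output starts out near zero, the residual connection makes the whole layer behave almost like the identity, so the network only has to learn the deviation:

```python
def tiny_block(x):
    # Stand-in for an attention/MLP block whose weights are near zero.
    return [0.001 * xi for xi in x]

x = [1.0, -2.0, 3.0]
out = [a + b for a, b in zip(tiny_block(x), x)]  # residual add
# out is almost exactly x
```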

Summary

The GPT forward pass:

  1. Embed the token and position
  2. Normalize the embeddings
  3. Process through each transformer layer:
    • Multi-head attention (with residual)
    • MLP (with residual)
  4. Project to vocabulary for next-token predictions

That's how a single token flows through the network to produce predictions!
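Stringing the steps together, generation is just this forward pass in a loop. A self-contained sketch, where the hypothetical fake_gpt stands in for the real gpt() and always returns the same logits:

```python
import math

def fake_gpt(token_id, pos_id):
    # Stand-in for gpt(): one raw score per vocabulary entry.
    return [0.1, 2.0, 0.3]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

token_id = 0
generated = []
for pos in range(4):
    probs = softmax(fake_gpt(token_id, pos))
    token_id = probs.index(max(probs))  # greedy: pick most likely token
    generated.append(token_id)
# generated == [1, 1, 1, 1] with these fixed logits
```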
