# RMSNorm

*Normalizing activations for stable training*
RMSNorm (Root Mean Square Normalization) is a technique to keep the numbers in a neural network at a reasonable scale. This helps with training stability.
## The Problem
As data flows through a neural network, the values can:
- Grow uncontrollably (explode)
- Shrink to nearly zero (vanish)
This makes training difficult. We need to keep values in check.
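A toy illustration of the exploding case (not from the original network; the ×3 "layer" is a stand-in for any transformation that amplifies activations):

```python
# Each "layer" multiplies activations by 3; after 10 layers the
# values have grown by a factor of 3**10 = 59049.
x = [1.0, 2.0]
for _ in range(10):
    x = [3.0 * xi for xi in x]
print(x)  # [59049.0, 118098.0]
```

Normalization between layers prevents this runaway growth.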
## What Does RMSNorm Do?
RMSNorm normalizes a vector to have consistent "energy":
```python
def rmsnorm(x):
    ms = sum(xi * xi for xi in x) / len(x)  # Mean square
    scale = (ms + 1e-5) ** -0.5             # Inverse RMS
    return [xi * scale for xi in x]         # Apply scaling
```

## Step by Step
Given a vector: x = [3.0, 4.0]
### Step 1: Compute Mean Square

```
ms = (3² + 4²) / 2 = (9 + 16) / 2 = 12.5
```

### Step 2: Compute Scale Factor
```
scale = (12.5 + 0.00001) ^ -0.5
      = 1 / sqrt(12.5)
      ≈ 0.283
```

### Step 3: Apply Scale
```
output = [3.0 * 0.283, 4.0 * 0.283]
       = [0.85, 1.13]
```

The vector is now normalized!
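The worked example can be checked directly with the `rmsnorm` function from above (repeated here so the snippet is self-contained):

```python
def rmsnorm(x):
    ms = sum(xi * xi for xi in x) / len(x)  # Mean square
    scale = (ms + 1e-5) ** -0.5             # Inverse RMS
    return [xi * scale for xi in x]         # Apply scaling

out = rmsnorm([3.0, 4.0])
print([round(v, 2) for v in out])  # [0.85, 1.13]
```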
## Why Does This Work?
The RMS (Root Mean Square) is:

```
RMS(x) = sqrt(mean(x²))
```

This measures the "size" of the vector. Dividing by RMS ensures:
- Output always has RMS of 1
- Output scale stays consistent
The small 1e-5 prevents division by zero.
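A quick check of this property: whatever the input's scale, the output's RMS comes out at (essentially) 1.

```python
import math

def rmsnorm(x):
    ms = sum(xi * xi for xi in x) / len(x)
    scale = (ms + 1e-5) ** -0.5
    return [xi * scale for xi in x]

def rms(x):
    # Root mean square: sqrt(mean(x²))
    return math.sqrt(sum(xi * xi for xi in x) / len(x))

# Wildly different input scales, same output RMS.
print(round(rms(rmsnorm([1000.0, -3.0, 0.5])), 3))  # 1.0
print(round(rms(rmsnorm([0.01, 0.02, 0.03])), 3))
```

(The second RMS is slightly below 1 because the 1e-5 epsilon is not negligible for such tiny inputs; for typical activation scales it has no visible effect.)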
## Visual Comparison
Before RMSNorm:

```
x = [1000, 2000, 3000]   # Huge values!
RMS ≈ 2160
```

After RMSNorm:

```
x = [0.46, 0.93, 1.39]   # Reasonable values
RMS ≈ 1.0
```

## RMSNorm vs LayerNorm
RMSNorm is simpler than LayerNorm:
| Feature | LayerNorm | RMSNorm |
|---|---|---|
| Formula | (x - mean) / sqrt(var + ε) | x / sqrt(mean(x²) + ε) |
| Compute mean | Yes | No |
| Parameters | Scale + bias (γ, β) | Scale only (γ) |
| Speed | Slower | Faster |
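The table's formulas can be sketched side by side. These are minimal plain-Python versions without the learnable parameters, for illustration only:

```python
import math

def layernorm(x, eps=1e-5):
    # Subtract the mean, then divide by the standard deviation.
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]

def rmsnorm(x, eps=1e-5):
    # No mean subtraction: just divide by the RMS.
    ms = sum(xi * xi for xi in x) / len(x)
    return [xi / math.sqrt(ms + eps) for xi in x]

x = [3.0, 4.0]
print([round(v, 2) for v in layernorm(x)])  # [-1.0, 1.0]  (centered)
print([round(v, 2) for v in rmsnorm(x)])    # [0.85, 1.13] (scaled only)
```

RMSNorm skips the mean computation and subtraction entirely, which is where its speed advantage comes from.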
Microgpt uses RMSNorm because it's simpler and works just as well!
## Where Is RMSNorm Used?
In GPT, RMSNorm is applied:
- After embeddings - normalize the combined embeddings
- Before attention - normalize before computing Q, K, V
- Before MLP - normalize before the feed-forward network
This keeps values stable throughout the network.
## The Residual Connection
RMSNorm is often used with residual connections (adding input to output):
```python
x_residual = x            # Save the input
x = rmsnorm(x)            # Normalize
x = attention_block(x)    # Transform
x = x + x_residual        # Add back the original
```

The residual helps gradients flow through the network, while RMSNorm keeps the scale stable.
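A runnable toy version of this pre-norm pattern. The `toy_transform` function is a stand-in for the real attention block, which is outside this snippet's scope:

```python
def rmsnorm(x):
    ms = sum(xi * xi for xi in x) / len(x)
    scale = (ms + 1e-5) ** -0.5
    return [xi * scale for xi in x]

def toy_transform(x):
    # Stand-in for an attention or MLP block: just doubles each value.
    return [2.0 * xi for xi in x]

x = [3.0, 4.0]
x_residual = x                                # Save the input
x = rmsnorm(x)                                # Normalize
x = toy_transform(x)                          # Transform
x = [a + b for a, b in zip(x, x_residual)]    # Add back the original
print([round(v, 2) for v in x])
```

Note that the normalization happens on the branch into the transform, not on the residual path, so the original signal passes through untouched.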
## Summary
RMSNorm normalizes activations:
- Compute the root mean square of the vector
- Scale to maintain consistent magnitude
- Apply before key operations
- Benefits: stable training, faster than LayerNorm
This is one of the key techniques that makes training large neural networks possible!