
Softmax Explained

Converting raw scores into probabilities

Softmax is a function that converts raw scores (called logits) into probabilities that sum to 1.

The Problem

Imagine your model outputs raw scores:

Logits: [2.0, 1.0, 0.1, -0.5, 1.5]

These are hard to interpret. We want to know: "What's the probability of each option?"

The Softmax Formula

softmax(x)[i] = exp(x[i]) / sum(exp(x[j]) for all j)
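Because e^(x[i] - c) = e^(-c) · e^(x[i]), subtracting the same constant c from every logit cancels out of the ratio. This is why the code can subtract the maximum logit (c = max over j of x[j]) without changing the result. After the shift, the largest exponent is exp(0) = 1, so nothing overflows:

```latex
\mathrm{softmax}(x - c)_i
  = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}}
  = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_j e^{x_j}}
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \mathrm{softmax}(x)_i
```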

In code:

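def softmax(logits):
    # logits: list of autograd values (assumed: each exposes .data and .exp())
    max_val = max(val.data for val in logits)          # plain float, treated as a constant
    exps = [(val - max_val).exp() for val in logits]   # shift so the largest exponent is 0
    total = sum(exps)                                  # sum() starts at 0, so 0 + value must work
    return [e / total for e in exps]                   # normalize: outputs sum to 1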
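The `softmax` above operates on autograd values: `val.data` and `.exp()` suggest a scalar `Value` type defined elsewhere in microgpt (an assumption, since it is not shown on this page). To try the arithmetic without that class, here is a minimal sketch on plain Python floats. `softmax_floats` is a hypothetical name, and this version tracks no gradients.

```python
import math


def softmax_floats(logits: list[float]) -> list[float]:
    """Numerically stable softmax over plain floats (hypothetical helper)."""
    if not logits:
        raise ValueError("softmax needs at least one logit")
    max_val = max(logits)                           # largest logit
    exps = [math.exp(x - max_val) for x in logits]  # every exponent is <= 0, so no overflow
    total = sum(exps)                               # >= 1.0, because the max term is exp(0)
    return [e / total for e in exps]                # normalize so the outputs sum to 1


probs = softmax_floats([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]
```

Because the maximum is subtracted first, even logits like [1000.0, 1000.0] produce [0.5, 0.5] instead of overflowing.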

Step by Step

Let's trace through an example:

Step 1: Original logits
logits = [2.0, 1.0, 0.1]

Step 2: Subtract max (for numerical stability)
shifted = [0.0, -1.0, -1.9]

Step 3: Exponential
exp([0.0, -1.0, -1.9]) = [1.0, 0.368, 0.150]

Step 4: Normalize (divide by sum)
sum = 1.517
softmax = [0.659, 0.242, 0.099]
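The trace above can be reproduced with plain floats and Python's `math.exp`. The second half uses made-up large logits (1000.0 and 999.0) to show why step 2 exists: `math.exp(1000.0)` raises `OverflowError`, while the shifted version works.

```python
import math

logits = [2.0, 1.0, 0.1]                 # Step 1: original logits
m = max(logits)
shifted = [x - m for x in logits]        # Step 2: subtract the max
exps = [math.exp(s) for s in shifted]    # Step 3: exponentiate
total = sum(exps)                        # Step 4: sum ...
probs = [e / total for e in exps]        # ... and divide

print([round(s, 1) for s in shifted])  # [0.0, -1.0, -1.9]
print([round(e, 3) for e in exps])     # [1.0, 0.368, 0.15]
print(round(total, 3))                 # 1.517
print([round(p, 3) for p in probs])    # [0.659, 0.242, 0.099]

# Why step 2 matters: exp() of a large logit overflows a float.
big = [1000.0, 999.0]
try:
    naive = [math.exp(x) for x in big]
except OverflowError:
    naive = None                         # math.exp(1000.0) is out of range
big_exps = [math.exp(x - max(big)) for x in big]
big_total = sum(big_exps)
big_probs = [e / big_total for e in big_exps]
print(naive, [round(p, 3) for p in big_probs])  # None [0.731, 0.269]
```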

Why Does This Work?

Exponential Amplifies Differences

logits:  [2.0, 1.0, 0.1]
exp:      [7.39, 2.72, 1.11]

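The highest logit (2.0) becomes 7.39, roughly 6.7 times the 1.11 produced by the lowest logit (0.1), even though the logits themselves differ by only 1.9. This is how softmax amplifies the gap between the model's preferred option and the rest.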
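A quick numeric check of this amplification on plain floats: the ratio between two exponentials depends only on the gap between their logits, because exp(a) / exp(b) = exp(a - b).

```python
import math

logits = [2.0, 1.0, 0.1]
exps = [math.exp(x) for x in logits]

# Ratio between the largest and the smallest exponential ...
ratio = exps[0] / exps[2]
# ... equals exp of the gap between the two logits (2.0 - 0.1 = 1.9).
gap_ratio = math.exp(logits[0] - logits[2])

print(round(ratio, 2), round(gap_ratio, 2))  # 6.69 6.69
```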

Normalization Makes Probabilities

Exponentiation already makes every value positive. Dividing by the sum then ensures:

  • Every value lies between 0 and 1
  • All values sum to 1.0
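A small check with plain floats, using made-up all-negative logits: the outputs are still positive and still sum to 1. Mathematically every output is strictly positive; in floating point, a logit more than about 745 below the maximum can underflow to exactly 0.0.

```python
import math

logits = [-3.0, -5.0, -10.0]              # all-negative logits
m = max(logits)
exps = [math.exp(x - m) for x in logits]  # exp(...) is always > 0 here
total = sum(exps)
probs = [e / total for e in exps]

print(all(p > 0 for p in probs))          # True
print(math.isclose(sum(probs), 1.0))      # True
```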

Visual Example

Logits:   [2.0,  1.0,  0.1]
           │     │     │
           ▼     ▼     ▼
Softmax:  [0.66, 0.24, 0.10]
           │     │     └─ 10% likely
           │     └─────── 24% likely
           └───────────── 66% likely (most likely)

The Temperature Parameter

During inference, the logits are divided by a "temperature" parameter before softmax is applied:

probs = softmax([l / temperature for l in logits])

Temperature    Effect
Low (0.1)      Very confident, almost greedy
Medium (0.7)   Balanced
High (1.0+)    More random, more creative

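Lower temperature means the model picks the most likely option more often. Dividing by a temperature below 1.0 widens the gaps between logits, so after exponentiation the top option dominates; a temperature above 1.0 narrows the gaps and flattens the distribution. A temperature of exactly 1.0 leaves the probabilities unchanged.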
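A minimal sketch of the temperature line above on plain floats, assuming a positive temperature. `softmax_with_temperature` is a hypothetical helper, not a function from microgpt; it prints the distribution for the temperatures in the table plus 2.0.

```python
import math


def softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    """Hypothetical helper: scale logits by 1/temperature, then stable softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]
for t in (0.1, 0.7, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

As the temperature grows, the top option's probability shrinks toward a uniform 1/3; as it approaches 0, sampling approaches argmax.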

Why Not Just Use Argmax?

logits = [2.0, 1.0, 0.1]
argmax = 0  # index of maximum

Argmax gives you the single best answer, but:

  • No sense of confidence
  • Can't sample for diversity
  • Can't use for training (not differentiable)

Softmax gives you a distribution that's:

  • Differentiable (can learn from)
  • Probabilistic (uncertainty matters)
  • Samplable (can generate diverse outputs)
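To make the contrast concrete, here is a sketch that takes the rounded probabilities from the step-by-step example, picks the argmax, and then samples with Python's standard `random.choices`. The seed and the sample count are arbitrary choices for illustration.

```python
import random

probs = [0.659, 0.242, 0.099]  # softmax of [2.0, 1.0, 0.1], rounded

# Argmax: always the same index, no matter how close the runner-up is.
argmax = max(range(len(probs)), key=lambda i: probs[i])

# Sampling: random.choices draws indices in proportion to the weights.
rng = random.Random(0)         # seeded so runs are repeatable
samples = rng.choices(range(len(probs)), weights=probs, k=10_000)
freqs = [samples.count(i) / len(samples) for i in range(len(probs))]

print(argmax)                        # 0
print([round(f, 2) for f in freqs])  # close to probs; exact values depend on the seed
```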

In the Model

Softmax is used:

  1. During training: compute loss from probabilities
  2. During inference: sample the next token
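The page does not name the training loss; a common choice, assumed here, is cross-entropy: the negative log of the probability softmax assigns to the correct next token. The sketch below computes it on plain floats for a hypothetical target index and checks the standard gradient with respect to the logits (probabilities minus a one-hot vector for the target) against finite differences. That gradient is the signal that differentiability makes available for every logit.

```python
import math


def softmax_floats(logits):
    # Hypothetical plain-float helper: stable softmax.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    # Loss = -log(probability assigned to the correct token).
    return -math.log(softmax_floats(logits)[target])


logits = [2.0, 1.0, 0.1]
target = 1  # hypothetical index of the correct next token
loss = cross_entropy(logits, target)

# Analytic gradient w.r.t. the logits: probs minus one-hot(target).
probs = softmax_floats(logits)
grad = [p - (1.0 if i == target else 0.0) for i, p in enumerate(probs)]

# Forward finite-difference check of the same gradient.
eps = 1e-6
numeric = []
for i in range(len(logits)):
    bumped = list(logits)
    bumped[i] += eps
    numeric.append((cross_entropy(bumped, target) - loss) / eps)

print(round(loss, 3))                                          # 1.417
print(all(abs(a - b) < 1e-4 for a, b in zip(grad, numeric)))  # True
```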

Summary

Softmax converts logits to probabilities:

  1. Exponentiate to make values positive and amplify differences
  2. Normalize so they sum to 1
  3. Temperature controls randomness during generation

This gives us a probability distribution we can learn from and sample from!
