Flash Attention Series - Mastering Softmax

April 09, 2026 • 10 min read

Flash Attention is currently one of the most important attention algorithms in mainstream use. Online softmax sits at the heart of it, and once you understand online softmax, Flash Attention becomes much easier to follow.

PyTorch defines softmax as a function applied to an n-dimensional tensor that squishes its elements into the range [0, 1] so that they sum to 1 along the chosen dimension.

$$ \mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} $$

In this post, I'll walk through 4 variations of softmax: a naive version, safe softmax, online softmax, and blocked softmax.

Naive Softmax

Well, if you had to guess just by looking at the formula, what would be the 3 major steps you need to compute softmax?

  1. The numerator is simply the exponential of each element in the tensor.
  2. The denominator is the sum of all exponentials - in other words, the sum of the numerator.
  3. We then divide each numerator element by the denominator.

How easy! Let's implement this by hand without invoking torch.nn.functional.softmax.

import torch

x = torch.tensor([0.2, 0.5, 0.1, -0.5, -0.4])

# step 1. numerator: exp of each element
num = torch.exp(x)

# step 2. denominator: sum of all exp values
denom = torch.sum(num)

# step 3. divide
result = num / denom
print(result)
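As a quick sanity check, the hand-rolled result can be compared against PyTorch's built-in `torch.softmax`. This is a minimal sketch on the same input; `torch.allclose` is used because the two results may differ in the last few bits of floating-point rounding.

```python
import torch

x = torch.tensor([0.2, 0.5, 0.1, -0.5, -0.4])

# naive softmax: the same three steps as above
num = torch.exp(x)
naive = num / num.sum()

# PyTorch's built-in softmax over the only dimension
reference = torch.softmax(x, dim=0)

print(torch.allclose(naive, reference))  # True
print(naive.sum())
```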
        

Safe Softmax

Oops, we forgot something important. Try plugging in a ridiculously large number and taking its exponential.

print(torch.exp(torch.tensor(100.0)))
# tensor(inf)
        

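This gives you infinity: a floating-point overflow. In float32, exp overflows once its argument exceeds about 88.7, and in float16 the limit is only about 11. Attention logits in a transformer can reach large magnitudes, so we need to handle this.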
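To see the failure concretely, here is the naive recipe applied to an input containing 100.0. Under float32 (PyTorch's default dtype), `exp(100.0)` overflows to `inf`, so the denominator is `inf` as well: every finite term divided by it becomes `0`, and the large entry becomes `inf / inf`, which is `nan`.

```python
import torch

x = torch.tensor([0.2, 0.5, 0.1, -0.5, -0.4, 100.0])  # float32 by default

num = torch.exp(x)     # the last entry overflows to inf
denom = num.sum()      # inf + finite values = inf
naive = num / denom    # finite / inf -> 0, inf / inf -> nan

print(num[-1], denom)
print(naive)
```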

What if we subtract the maximum value of the tensor from every element before computing the exponentials? The math still works out perfectly because the max cancels out:

$$ \mathrm{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_{j=1}^{n} e^{x_j - m}}, \quad m = \max_j(x_j) $$

Why does this work? Multiply both the numerator and denominator by \( e^{-m} \):

$$ \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{-m} \cdot e^{x_i}}{e^{-m} \cdot \sum_j e^{x_j}} = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}} $$

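The value of softmax is unchanged. But now the largest value in the exponent is always \( x_{\max} - m = 0 \), so \( e^0 = 1 \) is the biggest term - no more overflow. And because that term is part of the denominator, the denominator is always at least 1, so it can never underflow to zero.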

import torch

x = torch.tensor([0.2, 0.5, 0.1, -0.5, -0.4, 100.0])

# subtract the max for numerical stability
m = torch.max(x)
shifted_x = x - m

num = torch.exp(shifted_x)
denom = torch.sum(num)
result = num / denom

print(result)
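In attention, softmax is applied to every row of a score matrix rather than to a single vector. Below is a minimal sketch, assuming a 2D tensor of scores with one row per query; `safe_softmax` is a hypothetical helper name. The max is taken along the last dimension with `keepdim=True` so that it broadcasts back against each row.

```python
import torch

def safe_softmax(scores: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: row-wise safe softmax over the last dimension."""
    # per-row max, kept as shape (rows, 1) so it broadcasts over the columns
    m = scores.amax(dim=-1, keepdim=True)
    e = torch.exp(scores - m)
    # per-row denominator, also kept as (rows, 1)
    return e / e.sum(dim=-1, keepdim=True)

scores = torch.tensor([[0.2, 0.5, 0.1, -0.5, -0.4, 100.0],
                       [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
out = safe_softmax(scores)
print(out)
print(out.sum(dim=-1))  # each row sums to (approximately) 1
```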
        

This requires three passes over the data:

  1. one to find the max
  2. one to compute and sum the exponentials
  3. and one to divide

On a GPU, every pass reads the whole input from HBM (High Bandwidth Memory), which is much slower than on-chip memory, so each extra pass is expensive. Can we do better?
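To make the pass count concrete, here is the same safe softmax sketched with explicit Python loops, one loop per pass over the data. It uses plain Python lists and the `math` module purely for illustration; a real GPU kernel looks nothing like this, but each element is read the same number of times.

```python
import math

x = [0.2, 0.5, 0.1, -0.5, -0.4, 100.0]

# pass 1: read every element to find the max
m = -math.inf
for v in x:
    m = max(m, v)

# pass 2: read every element again to sum the shifted exponentials
d = 0.0
for v in x:
    d += math.exp(v - m)

# pass 3: read every element a third time to normalize
result = [math.exp(v - m) / d for v in x]

print(result)
print(sum(result))
```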

Online Softmax

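Now that we understand softmax, we need to think carefully about two problems:

  1. We currently need 3 passes - that's a lot of reads and writes to and from HBM. Can we reduce this to a single pass?
  2. Every exponential depends on the global max, so we can't start summing the denominator until we've seen the entire input.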

Online softmax solves both. The key idea: process the elements one by one while maintaining a running max \(m\) and a running denominator \(d\). Whenever the running max changes, rescale the old denominator so that it is expressed relative to the new max. (On a GPU, this bookkeeping happens in fast on-chip SRAM, which takes far fewer cycles to access than HBM.)

The recurrence at each step \(i\) is:

$$ m_i = \max(m_{i-1},\ x_i) $$ $$ d_i = d_{i-1} \cdot e^{m_{i-1} - m_i} + e^{x_i - m_i} $$

The correction factor \( e^{m_{i-1} - m_i} \) rescales all previously accumulated terms to be relative to the new max.
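Why does the correction factor make the recurrence correct? A short induction argument (a sketch, assuming the base case \( m_1 = x_1 \), \( d_1 = e^{0} = 1 \)) shows that \( d_i \) is always the safe-softmax denominator of the first \( i \) elements, measured relative to the current max. Suppose \( d_{i-1} = \sum_{j=1}^{i-1} e^{x_j - m_{i-1}} \). Then:

$$ \begin{aligned} d_i &= e^{m_{i-1} - m_i} \sum_{j=1}^{i-1} e^{x_j - m_{i-1}} + e^{x_i - m_i} \\ &= \sum_{j=1}^{i-1} e^{x_j - m_i} + e^{x_i - m_i} \\ &= \sum_{j=1}^{i} e^{x_j - m_i} \end{aligned} $$

After the last element, \( m_n = \max_j(x_j) = m \) and \( d_n = \sum_{j=1}^{n} e^{x_j - m} \), which is exactly the safe-softmax denominator.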

Let's work through the recurrence step by step on the same input:

x = torch.tensor([0.2, 0.5, 0.1, -0.5, -0.4, 100.0])
        

Step 1: x[0] = 0.2 initializes the running state

x[0] = 0.2
running_max = 0.2
running_denom = exp(0.2 - 0.2) = exp(0) = 1.0
        

Step 2: x[1] = 0.5, a new max

# new max found: 0.5 > 0.2
old_max = 0.2,  new_max = 0.5
# running_denom = running_denom * torch.exp(old_max - curr_max) + torch.exp(x[i] - curr_max)
running_denom = 1.0 * exp(0.2 - 0.5) + exp(0.5 - 0.5)
  = 1.0 * exp(-0.3)      + exp(0)
  = 1.0 * 0.7408         + 1.0
  = 1.7408

running_max = 0.5
        

Step 3: x[2] = 0.1, no new max

# 0.1 < 0.5, max unchanged
# running_denom += torch.exp(x[i] - curr_max)
running_denom = 1.7408 + exp(0.1 - 0.5)
  = 1.7408 + exp(-0.4)
  = 1.7408 + 0.6703
  = 2.4111
running_max = 0.5
        

Step 4: x[3] = -0.5, no new max

# -0.5 < 0.5, max unchanged
running_denom = 2.4111 + exp(-0.5 - 0.5)
  = 2.4111 + exp(-1.0)
  = 2.4111 + 0.3679
  = 2.7790
running_max = 0.5
        

Step 5: x[4] = -0.4, no new max

# -0.4 < 0.5, max unchanged
running_denom = 2.7790 + exp(-0.4 - 0.5)
  = 2.7790 + exp(-0.9)
  = 2.7790 + 0.4066
  = 3.1856
running_max = 0.5
        

Step 6: x[5] = 100.0, a new max

# new max found: 100.0 > 0.5
old_max = 0.5,  new_max = 100.0
running_denom = 3.1856 * exp(0.5 - 100.0) + exp(100.0 - 100.0)
  = 3.1856 * exp(-99.5)        + exp(0)
  = 3.1856 * 6.1e-44           + 1.0
  ≈ 0.0 + 1.0
  ≈ 1.0
running_max = 100.0
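If you'd rather not trust the hand arithmetic, the sketch below replays the same recurrence and records the running state after each element. It uses plain Python floats (64-bit) from the `math` module, so the printed values can differ from the rounded hand-computed ones in the last digit.

```python
import math

x = [0.2, 0.5, 0.1, -0.5, -0.4, 100.0]

running_max = x[0]
running_denom = math.exp(x[0] - running_max)  # exp(0) = 1.0
history = [running_denom]
print(f"step 1: max={running_max:.4f}, denom={running_denom:.4f}")

for i in range(1, len(x)):
    new_max = max(running_max, x[i])
    # when the max is unchanged, exp(running_max - new_max) = exp(0) = 1,
    # so this reduces to running_denom += exp(x[i] - running_max)
    running_denom = running_denom * math.exp(running_max - new_max) + math.exp(x[i] - new_max)
    running_max = new_max
    history.append(running_denom)
    print(f"step {i + 1}: max={running_max:.4f}, denom={running_denom:.4f}")
```

Writing the update in this single form (rescale, then add) avoids the if/else branch at the cost of an extra `exp(0)` when the max does not change.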
        

Finally, compute the numerator using the final running max and divide by the running denominator:

num = torch.exp(x - running_max)   # exp(x - 100) for each element
result = num / running_deno
        

Putting it all together, here is the full implementation:

import torch

x = torch.tensor([0.2, 0.5, 0.1, -0.5, -0.4, 100])

curr_max = x[0]
running_denom = torch.exp(x[0] - curr_max)

for i in range(1, x.shape[0]):
    if x[i] > curr_max:
        old_max = curr_max
        curr_max = x[i]
        running_denom = running_denom * torch.exp(old_max - curr_max) + torch.exp(x[i] - curr_max)
    else:
        running_denom += torch.exp(x[i] - curr_max)

num = torch.exp(x - curr_max)

result = num/running_denom
print(result)

# sanity check: the outputs should sum to 1 (up to floating-point rounding)
print(f"this should be equal to 1 : {torch.sum(result)}")
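A common variation (an alternative initialization, not the code above) starts the running max at negative infinity and the denominator at zero, so the first element needs no special case. This works because `exp(-inf)` is `0`, so the rescaling factor wipes out the empty initial denominator. Below is a minimal sketch, assuming a 1D float tensor with at least one finite element; `online_softmax` is a hypothetical helper name, and the result is checked against `torch.softmax`.

```python
import math
import torch

def online_softmax(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one pass for (max, denominator), one pass to normalize."""
    m = -math.inf   # running max: nothing seen yet
    d = 0.0         # running denominator: empty sum
    for v in x.tolist():
        new_m = max(m, v)
        # on the first element, exp(m - new_m) = exp(-inf) = 0.0
        d = d * math.exp(m - new_m) + math.exp(v - new_m)
        m = new_m
    # final pass: numerator relative to the global max, divided by the denominator
    return torch.exp(x - m) / d

x = torch.tensor([0.2, 0.5, 0.1, -0.5, -0.4, 100.0])
print(torch.allclose(online_softmax(x), torch.softmax(x, dim=0)))  # True
```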
        

We've reduced three passes to two: a single pass that computes the max and the denominator together, and a final pass that computes the normalized outputs.

Blocked Softmax

Online softmax processes elements sequentially. But GPUs are parallel machines - we want many threads doing work simultaneously. The idea of blocked softmax is to split the tensor into chunks, compute the max and denominator for each chunk independently, and then merge the results.

Each block's denominator is rescaled to be relative to the global max, and the rescaled denominators are summed. This is the same rescaling trick as in online softmax, applied at the block level.
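Written out for two blocks with local statistics \( (m^{(1)}, d^{(1)}) \) and \( (m^{(2)}, d^{(2)}) \), the merge step is (a sketch in the same notation as the safe-softmax formula):

$$ m = \max(m^{(1)}, m^{(2)}), \qquad d = d^{(1)} e^{m^{(1)} - m} + d^{(2)} e^{m^{(2)} - m} $$

Because \( d^{(k)} = \sum_{j \in \text{block } k} e^{x_j - m^{(k)}} \), each rescaled term equals \( \sum_{j \in \text{block } k} e^{x_j - m} \), so \( d \) is exactly the safe-softmax denominator of the whole tensor. The online softmax recurrence is the special case where the second block is a single element, with \( m^{(2)} = x_i \) and \( d^{(2)} = 1 \).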

import torch

x = torch.tensor([0.2, 0.5, 0.1, -0.5, -0.4, 100.0])

block1, block2 = x.split(3) # splits into 2 blocks

def block_stats(block: torch.Tensor):
    """computes local (max, denominator) for a block using online softmax."""
    m = block[0].item()
    d = torch.exp(block[0] - m) 
    for i in range(1, block.shape[0]):
        xi = block[i].item()
        if xi > m:
            old_m = m
            m = xi
            d = d * torch.exp(torch.tensor(old_m - m)) + torch.exp(block[i] - m)
        else:
            d += torch.exp(block[i] - m)
    return m, d

m1, d1 = block_stats(block1)
m2, d2 = block_stats(block2)

print(f"block1 -> max: {m1:.4f}, denom: {d1:.4f}")
print(f"block2 -> max: {m2:.4f}, denom: {d2:.4f}")

# merge: rescale each block's denom to the global max, then sum
final_m = max(m1, m2)
final_d = d1 * torch.exp(torch.tensor(m1 - final_m)) \
        + d2 * torch.exp(torch.tensor(m2 - final_m))

# final softmax output
result = torch.exp(x - final_m) / final_d

print(result)
print(f"sums to: {result.sum():.6f}")  # 1.0
        

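Each block's statistics are computed independently of every other block, so the blocks can run in parallel on different GPU threads or streaming multiprocessors; only the cheap merge step needs results from all of them.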
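To mimic that parallelism on the CPU, the statistics for every block can be computed at once with tensor operations instead of a Python loop. Below is a minimal sketch, assuming the tensor length is divisible by a hypothetical `block_size`; each row of `blocks` plays the role of one independent block.

```python
import torch

x = torch.tensor([0.2, 0.5, 0.1, -0.5, -0.4, 100.0])
block_size = 2  # hypothetical block size; must divide x.numel() in this sketch

blocks = x.view(-1, block_size)  # shape (num_blocks, block_size)

# per-block statistics, computed for all blocks simultaneously
m_blocks = blocks.amax(dim=1)                                # local max of each block
d_blocks = torch.exp(blocks - m_blocks[:, None]).sum(dim=1)  # local denominators

# merge: rescale every local denominator to the global max, then sum
m = m_blocks.max()
d = (d_blocks * torch.exp(m_blocks - m)).sum()

result = torch.exp(x - m) / d
print(torch.allclose(result, torch.softmax(x, dim=0)))  # True
```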

In the next post, we'll look at tiling for matrix multiplication and then combine both to build up the full Flash Attention algorithm!