Machine Learning

The Flash Attention Deep

Computing exact attention tile by tile to slash memory traffic.

6 min read · advanced

Same math, faster

Flash attention computes exactly the same softmax attention as the standard implementation, but reorganizes the computation so that it reads from and writes to slow memory far less often. It is an IO-aware algorithm, not an approximation.

The bottleneck it fixes

On a modern accelerator, the slow part of attention is usually not the arithmetic but moving data between small, fast on-chip memory (SRAM) and large, slower off-chip memory (HBM on GPUs). The naive method writes the full N × N score matrix, where N is the sequence length, to slow memory, then reads it back to apply the softmax. That traffic dominates the runtime.
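
To make the size of that score matrix concrete, here is a small back-of-the-envelope sketch. The helper name `score_matrix_bytes` and the choice of 2 bytes per element (16-bit floats) are assumptions for illustration, and the figure is for a single attention head on a single sequence.

```python
def score_matrix_bytes(seq_len: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to store one seq_len x seq_len score matrix.

    Hypothetical helper: assumes a dense matrix and 16-bit elements
    by default, for one head and one sequence.
    """
    return seq_len * seq_len * bytes_per_element


if __name__ == "__main__":
    for n in (1024, 8192, 65536):
        mib = score_matrix_bytes(n) / 2**20
        print(f"N = {n:>6}: {mib:,.0f} MiB per head")
```

The size grows with the square of the sequence length: at 65,536 tokens the matrix alone takes 8 GiB per head under these assumptions, and the naive method both writes and re-reads it.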

Tiling and the online softmax

Flash attention splits the queries, keys, and values into tiles that fit in fast memory. For each query tile it walks over the key and value tiles, and for each one it updates a running softmax using a numerically stable online formula. The formula tracks a running maximum and a running normalizer, and rescales the partial output whenever the maximum changes.

  • Never materializes the full score matrix in slow memory.
  • Keeps the result exact, up to floating-point rounding, through the running statistics.
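
The online softmax described above can be sketched in NumPy. This is a minimal illustration, not a real kernel: it runs on the CPU, tiles only the keys and values (not the queries), and the names `tiled_attention`, `naive_attention`, and `tile` are assumptions chosen for this example.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference: materializes the full score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    p = np.exp(scores)
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def tiled_attention(q, k, v, tile=4):
    """Same result, visiting keys/values one tile at a time."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros((n, v.shape[1]))        # unnormalized running output
    row_max = np.full((n, 1), -np.inf)     # running maximum per query row
    row_sum = np.zeros((n, 1))             # running normalizer per query row

    for start in range(0, k.shape[0], tile):
        k_tile = k[start:start + tile]
        v_tile = v[start:start + tile]
        s = (q @ k_tile.T) * scale         # scores for this tile only

        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)  # rescales earlier partial sums
        p = np.exp(s - new_max)

        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ v_tile
        row_max = new_max

    return out / row_sum


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, k, v = (rng.standard_normal((10, 8)) for _ in range(3))
    print(np.allclose(tiled_attention(q, k, v), naive_attention(q, k, v)))
```

Only a block of `tile` score columns exists at any time, yet the final output agrees with the reference up to floating-point rounding, because each update rescales the earlier partial sums to the new running maximum.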

Backward pass

The backward pass recomputes the needed scores on the fly, tile by tile, rather than storing them. Only the output and the per-row softmax statistics are kept from the forward pass. This trades a little extra compute for a large memory saving. The net effect is faster training and far lower memory use, which enables longer contexts.
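
As a sketch of the recomputation idea, the NumPy example below computes only the gradient with respect to the values, rebuilding each tile of softmax probabilities from the queries, keys, and a saved per-row log-sum-exp. The function names are hypothetical, and `row_logsumexp` forms the scores densely only to keep the sketch short; a real forward pass would produce that statistic from its running maximum and normalizer.

```python
import numpy as np


def row_logsumexp(q, k):
    """Per-row log-sum-exp of the scaled scores (the saved statistic).

    Computed densely here for brevity; this is an illustration only.
    """
    s = q @ k.T / np.sqrt(q.shape[-1])
    m = s.max(axis=-1, keepdims=True)
    return m + np.log(np.exp(s - m).sum(axis=-1, keepdims=True))


def grad_v_recomputed(q, k, d_out, lse, tile=4):
    """dV = P^T dO, with P rebuilt one key tile at a time."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    d_v = np.zeros((k.shape[0], d_out.shape[1]))
    for start in range(0, k.shape[0], tile):
        s = (q @ k[start:start + tile].T) * scale
        p = np.exp(s - lse)                  # recomputed softmax probabilities
        d_v[start:start + tile] = p.T @ d_out
    return d_v


def grad_v_dense(q, k, d_out):
    """Reference: stores the full probability matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p = p / p.sum(axis=-1, keepdims=True)
    return p.T @ d_out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, k = rng.standard_normal((6, 8)), rng.standard_normal((10, 8))
    d_out = rng.standard_normal((6, 5))
    lse = row_logsumexp(q, k)
    print(np.allclose(grad_v_recomputed(q, k, d_out, lse),
                      grad_v_dense(q, k, d_out)))
```

The saved statistic is one number per query row, so its size grows linearly with sequence length rather than quadratically. The gradients for the queries and keys follow the same pattern but are omitted here.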

Key idea

Flash attention computes exact softmax attention in tiles, using a numerically stable online softmax so that the full score matrix never touches slow memory. This cuts memory traffic and enables faster training and longer contexts.

Check yourself


1. Is flash attention an approximation of softmax attention?

2. What does flash attention avoid materializing in slow memory?

3. How does the backward pass save memory?