Same math, faster
Flash attention computes exactly the same softmax attention as the standard implementation, but reorganizes the computation so that it reads from and writes to slow memory far less. It is an IO-aware (input/output-aware) algorithm, not an approximation.
The bottleneck it fixes
On a modern accelerator, the slow part is moving data between the small, fast on-chip memory and the large, slow off-chip memory. The naive method writes the full score matrix, whose size is sequence length by sequence length, to slow memory, then reads it back to apply the softmax. That traffic dominates the running time.
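To give a sense of scale, here is a small back-of-the-envelope calculation of how large the score matrix is compared with the inputs and output. The sequence length of 8192, head dimension of 64, and 2 bytes per element are illustrative assumptions, not figures from the text, and the helper name is hypothetical.

```python
def attention_memory_bytes(seq_len, head_dim, bytes_per_element=2):
    """Return (score_matrix_bytes, qkvo_bytes) for a single attention head.

    All sizes are illustrative assumptions; 2 bytes per element
    corresponds to a 16-bit floating-point format.
    """
    # The naive method stores one score per (query, key) pair.
    score_matrix = seq_len * seq_len * bytes_per_element
    # Queries, keys, values, and the output are each seq_len x head_dim.
    qkvo = 4 * seq_len * head_dim * bytes_per_element
    return score_matrix, qkvo


scores, qkvo = attention_memory_bytes(seq_len=8192, head_dim=64)
print(scores // 2**20, "MiB for the score matrix")        # 128 MiB
print(qkvo // 2**20, "MiB for Q, K, V and the output")    # 4 MiB
```

Under these assumptions the score matrix is 32 times larger than everything else combined, and the gap widens as the sequence gets longer, because the score matrix grows with the square of the sequence length.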
Tiling and the online softmax
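Flash attention splits the queries, keys, and values into tiles that fit in fast memory. For each tile of queries it walks over the key and value tiles, and for each one it updates a running softmax using a numerically stable online formula that tracks a running maximum and a running normalizer. When a new tile raises the maximum, the sums accumulated so far are rescaled to match.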
- Never materializes the full score matrix in slow memory.
- Keeps the result exact, up to floating-point rounding, through the running statistics.
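The tiling loop and the running statistics described above can be sketched in NumPy as follows. This is a minimal CPU illustration of the arithmetic only, not the actual GPU kernel: the function names, the block sizes, and the 1/sqrt(d) scaling are assumptions made for the example. A plain reference version is included so the two can be compared.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference version: materializes the full (n_q, n_k) score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def tiled_attention(q, k, v, block_q=4, block_k=4):
    """Same result, but only one (block_q, block_k) score tile exists at a time."""
    n_q, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n_q, v.shape[-1]))
    for qs in range(0, n_q, block_q):
        q_tile = q[qs:qs + block_q]
        rows = q_tile.shape[0]
        m = np.full(rows, -np.inf)             # running row maximum
        l = np.zeros(rows)                     # running normalizer
        acc = np.zeros((rows, v.shape[-1]))    # unnormalized output
        for ks in range(0, k.shape[0], block_k):
            s = (q_tile @ k[ks:ks + block_k].T) * scale
            m_new = np.maximum(m, s.max(axis=-1))
            # Rescale what was accumulated under the old maximum.
            # On the first tile this is exp(-inf) = 0, which is harmless
            # because l and acc are still zero.
            correction = np.exp(m - m_new)
            p = np.exp(s - m_new[:, None])
            l = l * correction + p.sum(axis=-1)
            acc = acc * correction[:, None] + p @ v[ks:ks + block_k]
            m = m_new
        # Divide by the normalizer once, after all key tiles are seen.
        out[qs:qs + block_q] = acc / l[:, None]
    return out
```

Calling `tiled_attention(q, k, v)` on random inputs should agree with `naive_attention(q, k, v)` to within floating-point tolerance, for example under `np.allclose`, including when the sequence lengths are not multiples of the block sizes.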
Backward pass
The backward pass recomputes the needed scores on the fly rather than storing them, trading a little extra compute for a large memory saving. The net effect is faster training and far lower memory use, growing linearly rather than quadratically with sequence length, which enables longer contexts.
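The recomputation idea can be illustrated with a partial sketch in NumPy. It assumes the forward pass saved one number per query row, the log-sum-exp of that row's scaled scores, and it shows only the gradient with respect to the values; the gradients for queries and keys follow the same recompute pattern but are omitted here. The function names and block size are hypothetical, and this is not the actual kernel.

```python
import numpy as np


def row_logsumexp(q, k, block_k=4):
    """Per-row log-sum-exp of the scaled scores, built tile by tile."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    m = np.full(q.shape[0], -np.inf)   # running row maximum
    l = np.zeros(q.shape[0])           # running normalizer
    for ks in range(0, k.shape[0], block_k):
        s = (q @ k[ks:ks + block_k].T) * scale
        m_new = np.maximum(m, s.max(axis=-1))
        l = l * np.exp(m - m_new) + np.exp(s - m_new[:, None]).sum(axis=-1)
        m = m_new
    return m + np.log(l)


def grad_v_recomputed(q, k, d_out, lse, block_k=4):
    """Compute dV = P^T @ dO, rebuilding each tile of P from q, k and lse."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    d_v = np.empty((k.shape[0], d_out.shape[-1]))
    for ks in range(0, k.shape[0], block_k):
        # Recompute this tile's scores instead of reading stored ones.
        s = (q @ k[ks:ks + block_k].T) * scale
        # Subtracting the saved log-sum-exp gives the softmax
        # probabilities for this tile directly.
        p = np.exp(s - lse[:, None])
        d_v[ks:ks + block_k] = p.T @ d_out
    return d_v
```

Here `lse` holds one value per query row, so the saved state is proportional to the sequence length, while the probability tiles are rebuilt as needed and discarded.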
Key idea
Flash attention computes exact softmax attention in tiles, using a numerically stable online softmax, so the full score matrix never touches slow memory. This cuts memory traffic and enables faster training and longer contexts.