The memory bottleneck
Standard attention computes a full N × N score matrix, where N is the sequence length, writes it to GPU main memory (high-bandwidth memory, HBM), applies softmax, and then reads the result back to multiply it by the values. For long sequences this matrix is enormous, and the bottleneck is moving data between fast on-chip memory and slower main memory, not the arithmetic itself.
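A minimal pure-Python sketch of standard attention makes the problem visible. It assumes the usual scaled dot-product form with a 1/sqrt(d) scale, which the text above does not spell out, and the name `naive_attention` is hypothetical. Python lists stand in for GPU memory: the point is that the whole N × N score matrix S, and then the N × N probability matrix P, exist at once before V is touched.

```python
import math

def naive_attention(Q, K, V):
    """Standard scaled dot-product attention that materializes the full N x N matrices."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # Full score matrix S = Q K^T / sqrt(d): all N x N entries held at once.
    S = [[scale * sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]
    # Row-wise softmax P = softmax(S): a second N x N matrix.
    P = []
    for row in S:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        P.append([e / total for e in exps])
    # Output O = P V, computed from the stored probabilities: N x d_v.
    d_v = len(V[0])
    return [[sum(p_j * V[j][c] for j, p_j in enumerate(p)) for c in range(d_v)] for p in P]
```

On a GPU, S and P are what get written to and read back from main memory, which is the traffic the rest of this section is about removing.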
Tiling and fusion
Flash attention is an IO-aware kernel that never materializes the full score matrix. It processes attention in blocks (tiles) that fit in fast on-chip memory:
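- Load a block of queries into fast memory, then bring in blocks of keys and values one at a time.
- Compute the scores between the query block and the current key block, along with each row's maximum and sum of exponentials for that block.
- Merge each block into the running result with an online softmax: keep a running maximum and running sum per query row, and rescale the accumulated output whenever the maximum increases.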
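The online softmax step can be sketched for a single row of scores. This is a hypothetical pure-Python illustration (the function names are made up for this example), not a kernel: it keeps a running maximum m and a running sum l, and rescales l by exp(m_old - m_new) whenever a new block raises the maximum, so the final m and l match what a two-pass softmax would compute.

```python
import math

def online_softmax_stats(scores, block_size):
    """Running max m and running sum l = sum(exp(s - m)), computed one block at a time."""
    m = -math.inf  # running maximum of the scores seen so far
    l = 0.0        # running sum of exp(s - m) over the scores seen so far
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new max (exp(-inf) == 0.0 on the first block),
        # then add this block's contribution.
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    return m, l

def softmax_from_stats(scores, m, l):
    """Softmax probabilities recovered from the final running statistics."""
    return [math.exp(s - m) / l for s in scores]
```

For example, `online_softmax_stats([1.0, 3.0, -2.0, 0.5, 4.0], block_size=2)` returns a maximum of 4.0 and the same normalizer a single pass over all five scores would give, even though it only ever looks at two scores at a time.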
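By fusing the score computation, the softmax, and the multiplication by the values into a single kernel, and keeping intermediate blocks on-chip, it avoids the expensive round trips to main memory.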
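Putting the steps together, here is a sketch of the fused, tiled computation in pure Python. It assumes scaled dot-product attention and hypothetical block sizes `block_q` and `block_k`; Python lists stand in for on-chip tiles, so it demonstrates the arithmetic, not the memory-hierarchy speedup. Only one query row's scores against one key block exist at any moment, and the output accumulator is rescaled by the same factor as the running sum.

```python
import math

def tiled_attention(Q, K, V, block_q=2, block_k=2):
    """Exact attention computed tile by tile with an online softmax.

    The full N x N score matrix is never built.
    """
    d, d_v = len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for qs in range(0, len(Q), block_q):
        q_block = Q[qs:qs + block_q]
        m = [-math.inf] * len(q_block)        # running row maximum
        l = [0.0] * len(q_block)              # running sum of exp(s - m)
        acc = [[0.0] * d_v for _ in q_block]  # unnormalized output rows
        for ks in range(0, len(K), block_k):
            k_block, v_block = K[ks:ks + block_k], V[ks:ks + block_k]
            for i, q in enumerate(q_block):
                # Scores of one query row against the current key block only.
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_block]
                m_new = max(m[i], max(s))
                alpha = math.exp(m[i] - m_new)  # rescales everything accumulated so far
                p = [math.exp(x - m_new) for x in s]
                l[i] = l[i] * alpha + sum(p)
                acc[i] = [acc[i][c] * alpha + sum(pj * v[c] for pj, v in zip(p, v_block))
                          for c in range(d_v)]
                m[i] = m_new
        # Normalize once every key block has been seen.
        out.extend([a / l[i] for a in acc[i]] for i in range(len(q_block)))
    return out
```

Because the result does not depend on the block sizes, calling it with `block_q = block_k = len(Q)` (a single tile, i.e. ordinary attention) and with small tiles should agree to within floating-point rounding.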
Results
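Flash attention produces the same output as standard attention up to floating-point rounding; it is not an approximation. It runs faster in wall-clock time because it makes far fewer reads and writes to main memory, and its extra memory grows linearly rather than quadratically with sequence length, which enables much longer contexts.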
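To make the linear-versus-quadratic claim concrete, here is a back-of-the-envelope count of intermediate elements. It is a simplified model under stated assumptions: standard attention keeps one N × N score matrix, while the tiled version keeps a per-row maximum and sum plus one score tile of a hypothetical 128 × 128 size. Real kernels hold other buffers too, so only the growth rates are meaningful here.

```python
def attention_scratch_elements(seq_len, block_q=128, block_k=128):
    """Intermediate elements beyond Q, K, V and the output (simplified, assumed model)."""
    standard = seq_len * seq_len             # full score matrix
    tiled = 2 * seq_len + block_q * block_k  # per-row max and sum, plus one score tile
    return standard, tiled

for n in (8192, 16384):
    print(n, attention_scratch_elements(n))
# 8192 (67108864, 32768)
# 16384 (268435456, 49152)
```

Doubling N from 8192 to 16384 quadruples the standard count, while the tiled count grows by only 1.5× in this model, and that ratio approaches 2× (linear growth) as N dominates the fixed tile size.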
Key idea
Flash attention computes exact attention in on-chip blocks with an online softmax, avoiding the slow memory traffic of materializing the full score matrix.