The memory wall in attention
For a sequence of length N, standard attention forms an N by N score matrix, writes it to slow GPU memory, reads it back to apply softmax, writes the result, and reads it once more to multiply by the values. That quadratic memory traffic, not arithmetic, dominates the cost for long sequences.
The flash attention idea
Flash attention computes the same result without ever storing the full score matrix.
- It processes queries, keys, and values in tiles that fit in fast on-chip memory.
- It uses an online softmax that updates a running maximum and sum as tiles stream by.
- Each tile's partial output is rescaled to the current running maximum and accumulated, so the result is exact rather than an approximation.
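The three steps above can be sketched in NumPy. This is a minimal illustration of the arithmetic only, not a GPU kernel: the function names, the tile sizes, and the usual 1/sqrt(d) score scaling are assumptions made for the example, and ordinary array slices stand in for tiles held in on-chip memory.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference version: materializes the full (N, N) score matrix."""
    S = (Q @ K.T) / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, q_tile=4, kv_tile=4):
    """Same result, holding only one (q_tile, kv_tile) score block at a time."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)  # assumed scaling convention
    O = np.zeros((n, V.shape[1]))
    for i in range(0, n, q_tile):
        Qi = Q[i:i + q_tile]
        m = np.full(Qi.shape[0], -np.inf)           # running row maximum
        l = np.zeros(Qi.shape[0])                   # running softmax denominator
        acc = np.zeros((Qi.shape[0], V.shape[1]))   # unnormalized output
        for j in range(0, K.shape[0], kv_tile):
            S = (Qi @ K[j:j + kv_tile].T) * scale   # one small score tile
            m_new = np.maximum(m, S.max(axis=1))
            alpha = np.exp(m - m_new)               # rescales earlier partial results
            P = np.exp(S - m_new[:, None])
            l = alpha * l + P.sum(axis=1)
            acc = alpha[:, None] * acc + P @ V[j:j + kv_tile]
            m = m_new
        O[i:i + q_tile] = acc / l[:, None]          # normalize once per query tile
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
assert np.allclose(naive_attention(Q, K, V), tiled_attention(Q, K, V))
```

The factor `alpha` is what keeps the result exact: whenever a later tile raises the running maximum, everything accumulated so far is scaled down by the same amount, so the final division by `l` gives the same softmax weights as the untiled version.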
Why it is faster
- The expensive round trip of writing and then reading the giant score matrix is eliminated.
- The kernel is IO-aware: it minimizes traffic to high-bandwidth memory (HBM), which is large but much slower than on-chip memory.
- Memory use becomes linear in sequence length instead of quadratic.
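To put rough numbers on the quadratic versus linear difference, the short calculation below compares the size of one full score matrix with one saved softmax statistic per query row. The element sizes (2 bytes per score, 4 bytes per row statistic) and the helper names are assumptions chosen for illustration. The figures are per attention head and ignore the inputs and the output, which both approaches need.

```python
def score_matrix_bytes(seq_len, bytes_per_element=2):
    """Bytes for one full seq_len x seq_len score matrix (assumed 2-byte scores)."""
    return seq_len * seq_len * bytes_per_element

def row_stats_bytes(seq_len, bytes_per_element=4):
    """Bytes for one softmax statistic per query row (assumed 4-byte floats)."""
    return seq_len * bytes_per_element

def human(n_bytes):
    """Format a byte count using binary units."""
    for unit in ("B", "KiB", "MiB", "GiB"):
        if n_bytes < 1024 or unit == "GiB":
            return f"{n_bytes:g} {unit}"
        n_bytes /= 1024

for n in (1024, 8192, 65536):
    print(n, human(score_matrix_bytes(n)), human(row_stats_bytes(n)))
```

Under these assumptions, growing the sequence from 8192 to 65536 tokens takes the score matrix from 128 MiB to 8 GiB, while the per-row statistics grow only from 32 KiB to 256 KiB.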
The backward pass
The backward pass recomputes attention tiles from the queries and keys rather than storing them. This fits naturally with the tiling scheme and keeps memory low, and the gradients remain exact.
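The recomputation step can be illustrated as follows. This sketch assumes the forward pass saves one log-sum-exp value per query row, which is one common way to make recomputation cheap. The function names and the 1/sqrt(d) scaling are hypothetical choices for the example, and only the rebuilding of a probability tile is shown, not a full gradient computation.

```python
import numpy as np

def row_logsumexp(Q, K):
    """Per-row log-sum-exp of the scaled scores: one float per query row.

    Computed from the full matrix here for brevity; a tiled forward pass
    could obtain it from its running maximum m and sum l as m + log(l).
    """
    S = (Q @ K.T) / np.sqrt(Q.shape[-1])
    m = S.max(axis=1)
    return m + np.log(np.exp(S - m[:, None]).sum(axis=1))

def recompute_prob_tile(Q, K, lse, rows, cols):
    """Rebuild one tile of softmax probabilities from Q, K and the saved lse."""
    S = (Q[rows] @ K[cols].T) / np.sqrt(Q.shape[-1])
    return np.exp(S - lse[rows][:, None])

rng = np.random.default_rng(1)
Q, K = rng.standard_normal((8, 4)), rng.standard_normal((8, 4))
lse = row_logsumexp(Q, K)

# Full softmax, built only to check the recomputed tile against it.
S = (Q @ K.T) / np.sqrt(Q.shape[-1])
P_full = np.exp(S - S.max(axis=1, keepdims=True))
P_full /= P_full.sum(axis=1, keepdims=True)

rows, cols = slice(2, 6), slice(4, 8)
assert np.allclose(recompute_prob_tile(Q, K, lse, rows, cols), P_full[rows, cols])
```

Because the saved statistic already encodes each row's normalizer, any tile of probabilities can be rebuilt independently and in any order, which is why the backward pass needs only linear extra storage.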
Key idea
Flash attention tiles queries and keys and uses an online softmax so the full quadratic score matrix is never stored, cutting slow memory traffic and making attention memory linear in sequence length.