Machine Learning

The Flash Attention Memory

Computing attention without ever materializing the full quadratic score matrix.

The memory wall in attention

For a sequence of length N, standard attention forms an N by N score matrix, writes it to the GPU's slow off-chip memory, reads it back to apply softmax, and then reads the result again to multiply by the values. For long sequences that quadratic memory traffic, not the arithmetic, dominates the cost.
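To get a feel for the quadratic growth, here is a small sketch that computes the size of one score matrix. The function name and the 2 bytes per element (half precision) are assumptions for illustration, and the figure is per attention head and per batch item.

```python
def score_matrix_bytes(seq_len: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to store one seq_len x seq_len score matrix."""
    return seq_len * seq_len * bytes_per_element


for n in (1024, 8192, 65536):
    gib = score_matrix_bytes(n) / 2**30
    print(f"seq_len={n:>6}: {gib:.3f} GiB per head")
```

At 65,536 tokens a single half-precision score matrix already takes 8 GiB, and that is before multiplying by the number of heads and the batch size.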

The flash attention idea

Flash attention computes the same result without ever storing the full score matrix.

  • It processes queries, keys, and values in tiles that fit in fast on-chip memory.
  • It uses an online softmax that updates a running maximum and sum as tiles stream by.
  • Each tile output is rescaled and accumulated, so the result matches standard attention up to floating-point rounding rather than approximating it.
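The three steps above can be sketched in NumPy. This is a minimal illustration, not the actual GPU kernel: the function names and the tile size are assumptions, and only keys and values are tiled here (query rows are independent, so a real kernel also splits them into tiles).

```python
import numpy as np


def standard_attention(q, k, v):
    """Reference version: materializes the full (n_q, n_k) score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def tiled_attention(q, k, v, tile=4):
    """Streams over key/value tiles with a running max and sum per query row."""
    n_q, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros((n_q, v.shape[-1]))      # unnormalized output accumulator
    row_max = np.full((n_q, 1), -np.inf)    # running maximum of the scores
    row_sum = np.zeros((n_q, 1))            # running sum of exp(score - max)
    for start in range(0, k.shape[0], tile):
        k_t = k[start:start + tile]
        v_t = v[start:start + tile]
        s_t = q @ k_t.T * scale             # scores for this tile only
        new_max = np.maximum(row_max, s_t.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)  # rescales the old accumulators
        p_t = np.exp(s_t - new_max)
        row_sum = row_sum * correction + p_t.sum(axis=-1, keepdims=True)
        out = out * correction + p_t @ v_t
        row_max = new_max
    return out / row_sum


rng = np.random.default_rng(0)
q = rng.normal(size=(6, 8))
k = rng.normal(size=(10, 8))
v = rng.normal(size=(10, 5))
print(np.allclose(tiled_attention(q, k, v), standard_attention(q, k, v)))
```

The largest score array the tiled version ever holds is one (n_q, tile) block, yet the final output agrees with the reference to within floating-point rounding.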

Why it is faster

  • The expensive write and read of the giant score matrix is eliminated.
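  • The kernel is IO-aware: it keeps work in fast on-chip memory and minimizes traffic to off-chip high-bandwidth memory (HBM), which is large but much slower.
  • Extra memory beyond the inputs and output becomes linear in sequence length instead of quadratic.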
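The linear versus quadratic claim can be made concrete with a rough count of intermediate elements. This is a sketch under stated assumptions: the function names, the 128 by 128 tile, and the accounting (one score tile plus a running max and sum per query row) are illustrative, not measurements of any real kernel.

```python
def extra_elements_standard(n: int) -> int:
    """Standard attention stores the full n x n score matrix."""
    return n * n


def extra_elements_tiled(n: int, tile_q: int = 128, tile_k: int = 128) -> int:
    """Tiled attention: one score tile plus a running max and sum per row."""
    return tile_q * tile_k + 2 * n


for n in (1024, 4096, 16384):
    ratio = extra_elements_standard(n) / extra_elements_tiled(n)
    print(f"n={n:>6}: standard needs about {ratio:.0f}x more intermediate elements")
```

Doubling the sequence length quadruples the first count but adds only a term proportional to n to the second, which is why the gap widens as sequences get longer.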

The backward pass

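The backward pass recomputes attention tiles rather than storing them. Only the small per-row softmax statistics from the forward pass are saved, which is enough to rebuild any tile of attention probabilities on demand. This fits naturally with the tiling scheme and keeps memory low, and the gradients match those of standard attention up to floating-point rounding.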
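Here is a minimal NumPy sketch of that recomputation. It assumes the forward pass saves one log-sum-exp value per query row, and it shows only the gradient with respect to the values (dV = P^T dO) to stay short. The function names and tile size are hypothetical.

```python
import numpy as np


def forward_stats(q, k, tile=4):
    """Per-row log-sum-exp of the scaled scores, built tile by tile."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    row_max = np.full((q.shape[0], 1), -np.inf)
    row_sum = np.zeros((q.shape[0], 1))
    for start in range(0, k.shape[0], tile):
        s_t = q @ k[start:start + tile].T * scale
        new_max = np.maximum(row_max, s_t.max(axis=-1, keepdims=True))
        row_sum = row_sum * np.exp(row_max - new_max)
        row_sum = row_sum + np.exp(s_t - new_max).sum(axis=-1, keepdims=True)
        row_max = new_max
    return row_max + np.log(row_sum)


def grad_v_tiled(q, k, d_out, lse, tile=4):
    """dV = P^T dO, using probability tiles recomputed from the saved stats."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    pieces = []
    for start in range(0, k.shape[0], tile):
        s_t = q @ k[start:start + tile].T * scale
        p_t = np.exp(s_t - lse)           # recomputed here, never stored
        pieces.append(p_t.T @ d_out)      # gradient rows for this value tile
    return np.concatenate(pieces, axis=0)


rng = np.random.default_rng(0)
q = rng.normal(size=(6, 8))
k = rng.normal(size=(10, 8))
d_out = rng.normal(size=(6, 5))           # upstream gradient of the output

# Reference gradient that materializes the full probability matrix.
s = q @ k.T / np.sqrt(q.shape[-1])
p = np.exp(s - s.max(axis=-1, keepdims=True))
p = p / p.sum(axis=-1, keepdims=True)

lse = forward_stats(q, k)
print(np.allclose(grad_v_tiled(q, k, d_out, lse), p.T @ d_out))
```

The saved state is one number per query row, so the cost of skipping the stored matrix is a second pass of score computation rather than extra memory.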

Key idea

Flash attention tiles queries and keys and uses an online softmax so the full quadratic score matrix is never stored, cutting slow memory traffic and making attention memory linear in length.

Check yourself

1. What does flash attention avoid materializing?

2. What technique lets flash attention process tiles without the full matrix?

3. Flash attention is described as IO aware because it primarily reduces