The activation problem
During the backward pass, backpropagation needs the activations computed during the forward pass to evaluate each layer's gradients. Storing every activation for a deep network can use more memory than the weights themselves, and the cost grows with batch size, so it caps how deep or wide a model you can train.
The idea
Gradient checkpointing, also called activation recomputation, saves only a few checkpoint activations during the forward pass and discards the rest. During the backward pass, it recomputes the missing activations on demand by rerunning the forward computation from the nearest preceding checkpoint.
- Forward pass keeps activations only at chosen boundaries.
- Backward pass recomputes the intermediate activations segment by segment.
- Memory drops sharply because most activations are never held for the whole step; only one segment's intermediates exist at a time during backward.
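The three steps above can be sketched on a toy network. This is a minimal illustration under stated assumptions, not a framework implementation: each hypothetical layer is the scalar function `tanh(w * x)`, the loss is simply the final output, and the function names (`grads_full_storage`, `grads_checkpointed`) are made up for this example. The checkpointed version keeps only the input of every `segment`-th layer, then recomputes each segment's inputs from its checkpoint while walking backward.

```python
import math


def layer_forward(w, x):
    """Hypothetical toy layer: y = tanh(w * x)."""
    return math.tanh(w * x)


def layer_backward(w, x, grad_y):
    """Given the layer input x and dL/dy, return (dL/dx, dL/dw).

    The toy derivative re-evaluates tanh from the stored input.
    """
    y = math.tanh(w * x)
    local = (1.0 - y * y) * grad_y  # dL/d(w * x)
    return local * w, local * x


def grads_full_storage(weights, x0):
    """Baseline: keep every layer input from the forward pass."""
    inputs = []
    x = x0
    for w in weights:
        inputs.append(x)
        x = layer_forward(w, x)
    grad = 1.0  # loss is the final output, so dL/dy_last = 1
    grad_w = [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grad, grad_w[i] = layer_backward(weights[i], inputs[i], grad)
    return x, grad, grad_w, len(inputs)


def grads_checkpointed(weights, x0, segment):
    """Checkpointed: store only the input of every `segment`-th layer."""
    n = len(weights)
    checkpoints = {}
    x = x0
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = x  # keep a checkpoint at each segment boundary
        x = layer_forward(w, x)  # other inputs are discarded
    out = x
    grad = 1.0
    grad_w = [0.0] * n
    peak = len(checkpoints)
    # Walk the segments from last to first.
    for start in reversed(range(0, n, segment)):
        end = min(start + segment, n)
        # Recompute this segment's layer inputs from its checkpoint.
        seg_inputs = [checkpoints[start]]
        for i in range(start, end - 1):
            seg_inputs.append(layer_forward(weights[i], seg_inputs[-1]))
        # Worst case: all checkpoints plus this segment's recomputed inputs.
        peak = max(peak, len(checkpoints) + len(seg_inputs) - 1)
        for i in reversed(range(start, end)):
            grad, grad_w[i] = layer_backward(weights[i], seg_inputs[i - start], grad)
    return out, grad, grad_w, peak


weights = [0.5 + 0.1 * i for i in range(16)]
full = grads_full_storage(weights, 0.3)
ckpt = grads_checkpointed(weights, 0.3, segment=4)
print(full[2] == ckpt[2])  # True: same gradients
print(full[3], ckpt[3])    # 16 7: stored layer inputs at peak
```

Because the recomputation repeats exactly the same floating-point operations, the gradients match the baseline bit for bit, while the peak number of stored layer inputs drops from 16 to 7 (4 checkpoints plus 3 recomputed inputs in one segment).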
The tradeoff
You pay with extra compute. As a rough rule, recomputation costs one additional forward pass, so a training step takes around thirty to fifty percent more time. In exchange, activation memory can grow with the square root of the number of layers rather than with the number of layers itself, which often makes an otherwise impossible model fit.
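The square-root figure follows from simple counting. Assume, as in the simplest scheme, that you keep one checkpoint every k layers and, during backward, also hold the k - 1 recomputed inputs of the segment currently being processed. For n layers the peak is then about n/k + k - 1, which is smallest near k = √n, giving roughly 2√n instead of n. The sketch below checks this numerically; the function names are hypothetical, and the overhead estimate additionally assumes a backward pass costs about twice a forward pass, an assumption rather than a measured figure.

```python
import math


def peak_activations(n_layers, segment):
    """Peak stored layer inputs: all checkpoints plus one recomputed segment.

    The segment's first input is its checkpoint, already counted once.
    """
    n_checkpoints = math.ceil(n_layers / segment)
    return n_checkpoints + (segment - 1)


def best_segment(n_layers):
    """Segment length that minimizes peak memory (first minimum wins)."""
    return min(range(1, n_layers + 1), key=lambda k: peak_activations(n_layers, k))


def recompute_overhead(forward_cost=1.0, backward_cost=2.0):
    """Extra time from one extra forward pass, as a fraction of a normal step.

    Assumes backward costs about twice the forward pass by default.
    """
    return forward_cost / (forward_cost + backward_cost)


for n in (16, 100, 10000):
    k = best_segment(n)
    print(n, k, peak_activations(n, k))
# 16 4 7
# 100 10 19
# 10000 100 199
print(round(recompute_overhead(), 3))  # 0.333
```

With segment length 1 every input is a checkpoint and the peak equals n, which is the no-checkpointing baseline. Under the 2:1 backward-to-forward assumption, one extra forward pass lands at the low end of the thirty to fifty percent range.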
Key idea
Gradient checkpointing stores only a few activations and recomputes the rest during backward, trading roughly one extra forward pass of compute for a large reduction in activation memory.