The redundancy ZeRO removes
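Plain data parallelism replicates the model parameters, gradients, and optimizer state on every device. For an Adam-style optimizer trained in mixed precision, the optimizer state (fp32 master weights, momentum, and variance) takes about 12 bytes per parameter, against 2 bytes each for the fp16 weights and gradients, so most of the memory on each device holds redundant copies.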
The three stages
ZeRO partitions that redundant state across the data-parallel ranks in three cumulative stages.
- Stage one shards the optimizer states only.
- Stage two also shards the gradients.
- Stage three further shards the model parameters themselves.
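The effect of each stage on memory can be estimated with a little arithmetic. The sketch below assumes mixed-precision Adam accounting (2 bytes per parameter for fp16 weights, 2 for fp16 gradients, 12 for optimizer state) and counts model state only, not activations or buffers. The function name `zero_bytes_per_device` and the example sizes are illustrative choices, not part of any library.

```python
def zero_bytes_per_device(num_params: int, num_ranks: int, stage: int) -> float:
    """Estimate model-state bytes per device for a given ZeRO stage.

    Assumes mixed-precision Adam: 2 + 2 + 12 bytes per parameter.
    Stage 0 means plain data parallelism (nothing sharded).
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    params = 2.0 * num_params   # fp16 weights
    grads = 2.0 * num_params    # fp16 gradients
    optim = 12.0 * num_params   # fp32 master weights + momentum + variance
    if stage >= 1:
        optim /= num_ranks      # stage one: shard optimizer state
    if stage >= 2:
        grads /= num_ranks      # stage two: also shard gradients
    if stage >= 3:
        params /= num_ranks     # stage three: also shard parameters
    return params + grads + optim


if __name__ == "__main__":
    n, ranks = 7_500_000_000, 64  # example model size and rank count
    for stage in range(4):
        gb = zero_bytes_per_device(n, ranks, stage) / 1e9
        print(f"stage {stage}: {gb:.1f} GB per device")
```

Under these assumptions, only stage three makes the total shrink as 1/N; in stages one and two the unsharded weights (and, in stage one, gradients) set a floor.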
How compute still works
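- In stage three, each rank owns one shard of the parameters and all-gathers the rest on demand, typically layer by layer, for the forward and backward passes.
- After use, the gathered pieces are released, so peak memory stays low.
- Communication volume matches plain data parallelism in stages one and two and rises by about half in stage three, which stays manageable on a fast interconnect.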
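The gather, use, and release pattern can be illustrated with a single-process toy simulation. This is a sketch only: `shard`, `all_gather`, and `forward_on_rank` are hypothetical helpers, the "layer" is a made-up scalar operation, and the all-gather is a plain list concatenation standing in for a real collective.

```python
from typing import List, Tuple


def shard(values: List[float], num_ranks: int) -> List[List[float]]:
    """Split a flat parameter list into num_ranks contiguous shards."""
    size = -(-len(values) // num_ranks)  # ceiling division
    return [values[r * size:(r + 1) * size] for r in range(num_ranks)]


def all_gather(shards: List[List[float]]) -> List[float]:
    """Stand-in for a collective all-gather: concatenate every rank's shard."""
    return [v for s in shards for v in s]


def forward_on_rank(
    layers_sharded: List[List[List[float]]], rank: int, x: float
) -> Tuple[float, int]:
    """Toy forward pass on one rank, gathering one layer at a time.

    Returns the output and the peak count of parameter values held at once.
    """
    owned = sum(len(layer[rank]) for layer in layers_sharded)
    peak = owned
    for layer in layers_sharded:
        full = all_gather(layer)  # temporarily materialize the full layer
        peak = max(peak, owned - len(layer[rank]) + len(full))
        x = sum(w * x for w in full)  # toy layer: scale x by the sum of weights
        del full  # release the gathered copy before the next layer
    return x, peak


if __name__ == "__main__":
    layers = [[1.0] * 8, [0.5] * 8, [0.25] * 8]  # 24 parameters in total
    sharded = [shard(layer, 4) for layer in layers]
    y, peak = forward_on_rank(sharded, rank=0, x=1.0)
    print(f"output={y}, peak values held={peak}, total parameters=24")
```

The rank never holds more than its own shards plus one fully gathered layer, which is why peak memory stays well below the full model size.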
Why it is attractive
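- It keeps the simple data-parallel programming model, with no manual layer splitting.
- Per-device memory for the sharded state shrinks roughly as 1/N with N ranks; activations and temporary buffers are not covered by the three stages.
- Offloading shards to CPU memory or NVMe storage extends the reachable model size further, at a cost in speed.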
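As one concrete illustration of the offloading option, the sketch below builds a configuration fragment in the style of DeepSpeed, a library that implements ZeRO. Treating DeepSpeed as the framework in use is an assumption here, the NVMe path is a placeholder, and a real configuration needs further settings (batch size, precision, and so on) that are omitted.

```python
import json

# Assumed DeepSpeed-style fragment: stage three with optimizer state
# offloaded to CPU memory and parameters offloaded to NVMe storage.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        # "/local_nvme" is a hypothetical mount point; substitute a real one.
        "offload_param": {"device": "nvme", "nvme_path": "/local_nvme"},
    },
}

if __name__ == "__main__":
    # Print the fragment as JSON, the format such configs are usually stored in.
    print(json.dumps(ds_config, indent=2))
```

Each offload target trades device memory for transfer time, so it is usually worth enabling only the offloads needed to make the model fit.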
Key idea
ZeRO shards the redundant optimizer state, then the gradients, then the parameters across data-parallel ranks in three stages, cutting per-device memory while keeping the data-parallel programming model.