Removing the duplicate copies
Plain data parallelism replicates the full model state on every device: parameters, gradients, and optimizer state. For large models this duplication wastes a great deal of memory. The zero redundancy optimizer (ZeRO) instead shards the duplicated state, so each device stores only its own slice.
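To make "each device stores only a slice" concrete, here is a minimal sketch. It assumes all state is flattened into one vector and split into equal contiguous shards by rank, with the vector conceptually padded to a multiple of the device count. `shard_bounds` is a hypothetical helper, not a library API.

```python
import math


def shard_bounds(numel: int, world_size: int, rank: int) -> tuple[int, int]:
    """Return the [start, end) slice of a flat state vector owned by `rank`.

    Hypothetical helper: the vector is treated as padded up to a multiple of
    `world_size` so every shard has the same nominal length; clipping to
    `numel` means the last shard may be shorter, or empty.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    shard_len = math.ceil(numel / world_size)
    start = min(rank * shard_len, numel)
    end = min(start + shard_len, numel)
    return start, end


# A 10-element state vector split over 4 devices.
for rank in range(4):
    print(rank, shard_bounds(10, 4, rank))
```

Equal-length shards keep every device's share of the state the same size, which is what lets per-device memory drop roughly in proportion to the number of devices.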
- Stage one shards the optimizer state.
- Stage two also shards the gradients.
- Stage three also shards the parameters.
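The memory effect of each stage can be estimated with simple arithmetic. This sketch follows the accounting in the original ZeRO paper and assumes mixed-precision Adam: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and K = 12 bytes of optimizer state (fp32 master weights, momentum, and variance). Activations, temporary buffers, and fragmentation are ignored. `model_state_bytes` is a hypothetical helper, and stage 0 stands for plain data parallelism.

```python
def model_state_bytes(num_params: float, num_devices: int, stage: int, k: int = 12) -> float:
    """Bytes of model state each device holds (stage 0 = plain data parallelism).

    Assumes mixed-precision Adam: 2 bytes/param fp16 weights, 2 bytes/param
    fp16 gradients, and k bytes/param of optimizer state. Activations ignored.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    params = 2 * num_params
    grads = 2 * num_params
    optimizer = k * num_params
    if stage >= 1:
        optimizer /= num_devices  # stage one: shard optimizer state
    if stage >= 2:
        grads /= num_devices      # stage two: also shard gradients
    if stage >= 3:
        params /= num_devices     # stage three: also shard parameters
    return params + grads + optimizer


for stage in range(4):
    gb = model_state_bytes(7.5e9, 64, stage) / 1e9
    print(f"stage {stage}: {gb:.1f} GB per device")
```

For 7.5 billion parameters on 64 devices this gives 120 GB per device without sharding, and about 31.4 GB, 16.6 GB, and 1.9 GB for stages one, two, and three. Optimizer state dominates (12 of the 16 bytes per parameter), which is why stage one alone already recovers most of the savings.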
Memory versus communication
Each stage frees more memory, but the sharded pieces must be exchanged between devices whenever a full copy is needed. The result is that a model far larger than a single device's memory can be trained without true model parallelism, by treating the data parallel devices as a distributed store of model state.
- It keeps the programming model of data parallelism.
- Sharding the parameters as well (stage three) adds gather traffic, because each layer's weights must be collected before they are used.
- It can be combined with offloading state to host (CPU) memory to fit even larger models.
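How much extra traffic sharding costs can be estimated roughly. This sketch follows the communication analysis in the original ZeRO paper and assumes ring collectives, where a reduce-scatter or an all-gather over Ψ elements costs about Ψ elements sent per device (the (N-1)/N factor is dropped), and an all-reduce costs one of each. `comm_elements_per_device` is a hypothetical helper.

```python
def comm_elements_per_device(num_params: float, stage: int) -> float:
    """Rough elements each device sends per training step (large-N ring estimate).

    Assumes a ring reduce-scatter or all-gather over `num_params` elements
    costs about `num_params` per device, and an all-reduce is one of each.
    """
    if stage in (0, 1, 2):
        # Stage 0 (plain DP): all-reduce of gradients = reduce-scatter + all-gather.
        # Stages 1-2: reduce-scatter gradients, then all-gather updated parameters.
        return 2 * num_params
    if stage == 3:
        # All-gather parameters for the forward pass, again for the backward
        # pass, plus a reduce-scatter of gradients.
        return 3 * num_params
    raise ValueError("stage must be 0, 1, 2, or 3")


for stage in range(4):
    print(stage, comm_elements_per_device(1.0, stage))
```

Under this estimate, stages one and two move about as much data as a plain gradient all-reduce, and stage three adds roughly 50%, which is the communication price of also sharding the parameters.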
Sharded state
Because no device holds a redundant copy of the state, the cluster can fit models whose state would otherwise overflow each device's memory.
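The following toy simulation shows that sharded state still produces the same update as full replication. It is a sketch under several assumptions: plain Python lists stand in for device memory, momentum SGD replaces Adam (one state vector instead of two), and every simulated device still materializes its full local gradient and full parameter vector, so it mirrors stages one and two rather than stage three. All function names are hypothetical. Gradients are reduce-scattered, each device updates only the slice it owns with its own momentum shard, and an all-gather reassembles the parameters.

```python
# Hypothetical simulation: each inner list stands in for one device's memory.
Vector = list[float]


def shard_range(n: int, world_size: int, rank: int) -> tuple[int, int]:
    """Contiguous [lo, hi) slice of an n-element vector owned by `rank`."""
    shard = -(-n // world_size)  # ceil(n / world_size)
    lo = min(rank * shard, n)
    return lo, min(lo + shard, n)


def init_momentum_shards(n: int, world_size: int) -> list[Vector]:
    """Each device allocates momentum only for the elements it owns."""
    out = []
    for r in range(world_size):
        lo, hi = shard_range(n, world_size, r)
        out.append([0.0] * (hi - lo))
    return out


def reduce_scatter(per_device: list[Vector]) -> list[Vector]:
    """Sum gradients across devices; device r keeps only its own shard."""
    world_size, n = len(per_device), len(per_device[0])
    out = []
    for r in range(world_size):
        lo, hi = shard_range(n, world_size, r)
        out.append([sum(v[i] for v in per_device) for i in range(lo, hi)])
    return out


def all_gather(shards: list[Vector]) -> Vector:
    """Concatenate every device's shard so all devices see the full vector."""
    return [x for s in shards for x in s]


def sharded_step(params, local_grads, momentum_shards, lr=0.1, beta=0.9):
    """One momentum-SGD step with sharded gradients and optimizer state."""
    world_size = len(local_grads)
    grad_shards = reduce_scatter(local_grads)
    new_shards = []
    for r, (g, m) in enumerate(zip(grad_shards, momentum_shards)):
        lo, _ = shard_range(len(params), world_size, r)
        for i in range(len(g)):
            m[i] = beta * m[i] + g[i] / world_size  # update owned state only
        new_shards.append([params[lo + i] - lr * m[i] for i in range(len(g))])
    return all_gather(new_shards)  # every device receives the full parameters


def replicated_step(params, local_grads, momentum, lr=0.1, beta=0.9):
    """Plain data-parallel reference: full gradient and full state everywhere."""
    world_size = len(local_grads)
    avg = [sum(col) / world_size for col in zip(*local_grads)]
    for i in range(len(momentum)):
        momentum[i] = beta * momentum[i] + avg[i]
    return [p - lr * m for p, m in zip(params, momentum)]


params = [1.0, -2.0, 0.5, 3.0, 0.0]
local_grads = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1], [1.0, 0.0, -1.0, 0.0, 1.0]]
momentum_shards = init_momentum_shards(len(params), len(local_grads))
sharded = sharded_step(params, local_grads, momentum_shards)
replicated = replicated_step(params, local_grads, [0.0] * len(params))
print(max(abs(a - b) for a, b in zip(sharded, replicated)))
```

Both paths sum gradients in the same device order, so they should agree; the difference is that `momentum_shards` holds n momentum values in total across all devices, whereas the replicated version holds n on every device.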
Key idea
The zero redundancy optimizer shards optimizer state, gradients, and parameters across data parallel devices, trading communication for large memory savings.