The simplest way to scale
Data parallelism puts a full copy of the model on every device. The training batch is split into shards, one per device, so each worker computes gradients on a different slice of data at the same time.
- Every device holds the same weights.
- Each device processes a different shard of the mini-batch.
- Gradients are averaged so all replicas stay in sync.
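To make the three points above concrete, here is a minimal single-process sketch. It assumes a toy one-parameter linear model y ≈ w·x trained with mean squared error and plain SGD. The "devices" are just entries in a Python list, and names such as `data_parallel_step` and `split_batch` are illustrative only, not taken from any framework.

```python
from typing import List, Tuple

Example = Tuple[float, float]  # one (x, y) training pair


def split_batch(batch: List[Example], num_devices: int) -> List[List[Example]]:
    """Split the global batch into equal, contiguous shards, one per device."""
    if len(batch) % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    size = len(batch) // num_devices
    return [batch[i * size:(i + 1) * size] for i in range(num_devices)]


def shard_gradient(w: float, shard: List[Example]) -> float:
    """Gradient of mean((w*x - y)**2) over one shard, w.r.t. the scalar weight w."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)


def data_parallel_step(replicas: List[float], batch: List[Example], lr: float) -> List[float]:
    """One synchronous data-parallel SGD step; replicas[i] is device i's copy of w."""
    shards = split_batch(batch, len(replicas))
    # Each "device" computes a gradient on its own shard (sequentially here).
    grads = [shard_gradient(w, s) for w, s in zip(replicas, shards)]
    # Averaging the shard gradients stands in for the real gradient exchange.
    avg_grad = sum(grads) / len(grads)
    # Every replica applies the identical update, so the copies stay in sync.
    return [w - lr * avg_grad for w in replicas]


batch = [(float(x), 3.0 * x) for x in range(1, 9)]  # 8 examples of y = 3x
replicas = [0.0] * 4  # 4 devices, same initial weight
for _ in range(50):
    replicas = data_parallel_step(replicas, batch, lr=0.01)
print(replicas)
```

All four entries end up identical and close to 3.0. Because the shards are equal in size, one step on four replicas produces the same weight as one step on a single device that sees the whole batch.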
Why it works
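When the loss is the mean over a batch, the batch gradient is just the average of the per-example gradients. Each device can therefore compute the gradient on its own shard in parallel, and averaging those shard gradients recovers the full-batch gradient, provided the shards are the same size (otherwise each shard's gradient must be weighted by its number of examples). After averaging, every replica applies the same update, so the weights remain identical across devices.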
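The averaging argument can be checked numerically. This sketch reuses the assumption of a one-parameter model with mean squared error and a small hand-picked batch. It compares the full-batch gradient with the plain average of shard gradients for equal shards, and with a size-weighted average for unequal shards, where the plain average would be wrong.

```python
def mean_grad(w, data):
    """Gradient of mean((w*x - y)**2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in data) / len(data)


batch = [(1.0, 2.0), (2.0, 5.0), (3.0, 5.0), (4.0, 9.0)]
w = 1.0
full = mean_grad(w, batch)

# Equal shards: the plain average of per-shard gradients equals the full-batch gradient.
equal = [batch[:2], batch[2:]]
avg_equal = sum(mean_grad(w, s) for s in equal) / len(equal)

# Unequal shards: weight each shard's gradient by its number of examples.
unequal = [batch[:1], batch[1:]]
weighted = sum(len(s) * mean_grad(w, s) for s in unequal) / len(batch)
naive = sum(mean_grad(w, s) for s in unequal) / len(unequal)

print(full, avg_equal)                      # -16.5 -16.5
print(round(weighted, 6), round(naive, 6))  # -16.5 -11.666667
```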
- It scales well as long as the model fits on a single device.
- Throughput grows roughly linearly with the number of devices until communication dominates.
- It does nothing for models too large to fit in one device's memory.
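The "roughly linear until communication dominates" behaviour can be illustrated with a toy cost model. Everything here is an assumption made for illustration: each device keeps the same per-device batch, so compute time per step is fixed, and the gradient exchange adds a fixed cost plus a term that grows with the device count (loosely like latency in a ring). The constants are invented, not measurements, and `toy_throughput` is a hypothetical helper.

```python
def toy_throughput(num_devices, per_device_batch=32, compute_s=0.10,
                   comm_base_s=0.01, comm_per_device_s=0.002):
    """Samples per second under a made-up cost model (all constants are hypothetical)."""
    # A single device has nothing to exchange; otherwise communication grows with N.
    comm_s = 0.0 if num_devices == 1 else comm_base_s + comm_per_device_s * num_devices
    step_s = compute_s + comm_s
    return num_devices * per_device_batch / step_s


base = toy_throughput(1)
for n in (1, 2, 4, 8, 16, 64, 256):
    speedup = toy_throughput(n) / base
    print(f"{n:4d} devices: speedup {speedup:6.1f}, efficiency {speedup / n:.2f}")
```

Under these assumptions the efficiency (speedup divided by device count) equals compute time divided by total step time, so it stays near 1 while communication is small and falls steadily once the exchange becomes comparable to the compute.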
The sync step
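The main bottleneck is the gradient exchange in every step: no replica can apply its update until the averaged gradient has arrived. This is why collective communication operations such as all-reduce matter so much at scale.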
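One common algorithm for this exchange is ring all-reduce. The sketch below simulates it in a single process: a reduce-scatter phase in which workers pass chunks around a ring and add them up, followed by an all-gather phase that circulates the finished chunks. It assumes the gradient length divides evenly by the number of workers. It is a teaching simulation, not a real communication library, and `ring_allreduce_mean` is a hypothetical name.

```python
from typing import List


def ring_allreduce_mean(grads: List[List[float]]) -> List[List[float]]:
    """Simulate ring all-reduce; grads[i] is worker i's gradient. Returns each worker's averaged copy."""
    n = len(grads)
    length = len(grads[0])
    assert length % n == 0, "sketch assumes the gradient splits into n equal chunks"
    size = length // n
    # chunks[i][c] is worker i's current copy of chunk c.
    chunks = [[list(g[c * size:(c + 1) * size]) for c in range(n)] for g in grads]

    # Reduce-scatter: in step s, worker i sends chunk (i - s) % n to worker (i + 1) % n,
    # which adds it to its own copy. After n - 1 steps, worker i holds the full sum of chunk (i + 1) % n.
    for step in range(n - 1):
        msgs = [(i, (i - step) % n, list(chunks[i][(i - step) % n])) for i in range(n)]
        for i, c, data in msgs:  # all sends in a step happen "simultaneously"
            dst = (i + 1) % n
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], data)]

    # All-gather: in step s, worker i forwards its finished chunk (i + 1 - s) % n,
    # and the receiver overwrites its stale copy.
    for step in range(n - 1):
        msgs = [(i, (i + 1 - step) % n, list(chunks[i][(i + 1 - step) % n])) for i in range(n)]
        for i, c, data in msgs:
            chunks[(i + 1) % n][c] = data

    # Divide the sums by n to turn them into averages.
    return [[x / n for chunk in worker for x in chunk] for worker in chunks]


worker_grads = [[1.0, 2.0, 3.0, 4.0], [3.0, 2.0, 1.0, 0.0]]
print(ring_allreduce_mean(worker_grads))  # [[2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0]]
```

Each worker sends 2(N - 1) chunks of size L/N, so its total traffic is about 2L regardless of how many workers there are. That property is what makes the ring attractive, though its number of sequential steps still grows with N.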
Key idea
Data parallelism replicates the whole model on each device and splits the batch, then averages gradients so all replicas apply identical updates.