Data Parallel Training

Replicate the model across devices and split the batch to train faster.

The simplest way to scale

Data parallelism puts a full copy of the model on every device. The training batch is split into shards, one per device, so each worker computes gradients on a different slice of data at the same time.

  • Every device holds the same weights.
  • Each device sees a different mini-batch shard.
  • Gradients are averaged so all replicas stay in sync.
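
To make the sharding concrete, here is a minimal sketch in plain Python. The helper name shard_batch, the batch of eight example ids, and the four "devices" are all assumptions made for illustration; real frameworks usually shard inside their data loaders.

```python
def shard_batch(batch, num_devices):
    """Split a batch into num_devices equal, contiguous shards."""
    if len(batch) % num_devices != 0:
        raise ValueError("batch size must be divisible by num_devices")
    shard_size = len(batch) // num_devices
    return [batch[i * shard_size:(i + 1) * shard_size]
            for i in range(num_devices)]


batch = list(range(8))          # eight example ids standing in for real data
shards = shard_batch(batch, 4)  # one shard per hypothetical device

for device_id, shard in enumerate(shards):
    print(f"device {device_id}: {shard}")
```

Each simulated device receives a disjoint slice, and together the slices cover the whole batch. Requiring equal shard sizes is a simplification chosen here so that a plain average of gradients is valid later on.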

Why it works

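The gradient of a mean loss over a batch is the average of the per-example gradients. So each device can compute the gradient on its own shard in parallel, and averaging those shard gradients recovers the full-batch gradient, provided the shards are the same size. After averaging, every replica applies the same update, so the weights remain identical across devices.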

  • It scales well when the model fits on one device.
  • Throughput grows roughly linearly until communication dominates.
  • It does not help when the model is too large to fit in a single device's memory.
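
The averaging argument can be checked numerically. The sketch below is a toy under stated assumptions: a one-parameter model y_hat = w * x with mean squared error, two simulated devices, and equal-sized shards. The function names and the data are hypothetical, and the plain average stands in for whatever gradient exchange a real system would use.

```python
def mean_gradient(w, examples):
    """Gradient of mean squared error w.r.t. w for the model y_hat = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in examples) / len(examples)


def data_parallel_step(weights, shards, lr):
    """One simulated step: local gradients, average, identical update."""
    local_grads = [mean_gradient(w, shard) for w, shard in zip(weights, shards)]
    avg_grad = sum(local_grads) / len(local_grads)  # stand-in for the exchange
    return [w - lr * avg_grad for w in weights], avg_grad


examples = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
shards = [examples[:2], examples[2:]]  # two equal shards, one per device
weights = [0.5, 0.5]                   # every replica starts from the same weight

full_grad = mean_gradient(0.5, examples)
weights, avg_grad = data_parallel_step(weights, shards, lr=0.01)

print("full-batch gradient:", full_grad)
print("averaged shard gradient:", avg_grad)
print("replica weights after update:", weights)
```

The averaged shard gradient matches the full-batch gradient, and both replicas end the step with the same weight. If the shards had different sizes, a plain average would no longer match; each shard's gradient would need to be weighted by its number of examples.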

The sync step

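On every step, the replicas must exchange and average their gradients before applying the update, typically with an all-reduce. This exchange is the bottleneck, which is why collective communication matters so much at scale.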
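
A toy cost model shows why the exchange ends up dominating. It rests on strong assumptions that are not from the lesson: compute time divides evenly across devices, the communication cost per step is a fixed constant, and communication does not overlap with computation. The numbers 1.0 and 0.05 are arbitrary illustrative values.

```python
def step_time(num_devices, compute_time, comm_time):
    """Toy model: compute splits across devices, communication does not."""
    if num_devices == 1:
        return compute_time  # a single device has nothing to exchange
    return compute_time / num_devices + comm_time


def speedup(num_devices, compute_time=1.0, comm_time=0.05):
    """Speedup over one device under the toy model above."""
    return (step_time(1, compute_time, comm_time)
            / step_time(num_devices, compute_time, comm_time))


for n in (1, 2, 4, 8, 16, 64):
    print(f"{n:>3} devices: {speedup(n):.2f}x")
```

With few devices the speedup is close to linear. As the device count grows, the per-device compute time shrinks while the communication term stays put, so in this model the speedup can never exceed compute_time / comm_time. Real systems soften this by overlapping the gradient exchange with the backward pass, which the sketch deliberately ignores.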

Key idea

Data parallelism replicates the whole model on each device and splits the batch, then averages gradients so all replicas apply identical updates.

Check yourself

1. What does each device hold in data parallel training?

2. Why is averaging gradients correct?