Data Parallel Training

Replicate the model across GPUs and split the batch to train faster.

What it is

Data parallel training is the most common way to use many GPUs. Each device holds a full copy of the model and the optimizer state. The global batch is split into shards, and every GPU processes a different shard at the same time.
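As a minimal sketch of the batch split described above, the snippet below cuts a global batch into equal, contiguous shards, one per device. The helper name `split_batch` and the toy data are assumptions made for illustration; real frameworks usually handle sharding inside their data loaders.

```python
def split_batch(global_batch, num_gpus):
    """Split a global batch into num_gpus equal, contiguous shards.

    Hypothetical helper for illustration only.
    """
    if len(global_batch) % num_gpus != 0:
        raise ValueError("global batch size must be divisible by num_gpus")
    shard_size = len(global_batch) // num_gpus
    return [
        global_batch[i * shard_size:(i + 1) * shard_size]
        for i in range(num_gpus)
    ]


# Eight sample indices stand in for real training examples.
global_batch = list(range(8))
shards = split_batch(global_batch, num_gpus=4)
print(shards)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Each simulated GPU would then process exactly one entry of `shards`, so no sample is seen twice within a step.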

How a step works

  • Each GPU runs a forward pass on its shard and computes a local loss.
  • Each GPU runs a backward pass and gets local gradients.
  • The gradients are averaged across all GPUs so every copy sees the same update.
  • Every GPU applies the averaged gradient, keeping the replicas identical.

The averaging step is the key. Because all replicas start equal and apply the same update, they stay in sync without ever sharing weights directly.
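The four steps above can be sketched in plain Python by simulating the GPUs as entries in a list. This is a toy under stated assumptions, not how a real framework is implemented: the model is a one-variable linear fit `y = w*x + b` with mean-squared-error loss, and `all_reduce_mean` is a hypothetical stand-in for the collective operation that averages gradients across devices.

```python
def local_gradient(weights, shard):
    """Forward and backward pass on one shard: MSE gradient of y = w*x + b."""
    w, b = weights
    grad_w = grad_b = 0.0
    for x, y in shard:
        err = (w * x + b) - y        # forward pass: prediction error
        grad_w += 2.0 * err * x      # backward pass: d(loss)/dw
        grad_b += 2.0 * err          # backward pass: d(loss)/db
    n = len(shard)
    return [grad_w / n, grad_b / n]


def all_reduce_mean(grads_per_gpu):
    """Stand-in for an all-reduce: element-wise mean across replicas."""
    num_gpus = len(grads_per_gpu)
    num_params = len(grads_per_gpu[0])
    return [sum(g[i] for g in grads_per_gpu) / num_gpus for i in range(num_params)]


def data_parallel_step(replicas, shards, lr):
    """One training step over all simulated GPUs."""
    # Steps 1-2: every replica computes gradients on its own shard.
    local_grads = [local_gradient(w, s) for w, s in zip(replicas, shards)]
    # Step 3: gradients are averaged so every replica sees the same update.
    avg_grad = all_reduce_mean(local_grads)
    # Step 4: every replica applies the same averaged gradient.
    return [[p - lr * g for p, g in zip(w, avg_grad)] for w in replicas]


# Toy data drawn from y = 3x + 1, split into four equal shards.
data = [(float(x), 3.0 * x + 1.0) for x in range(8)]
shards = [data[0:2], data[2:4], data[4:6], data[6:8]]

# With equal shard sizes, the averaged gradient matches the full-batch gradient.
start = [0.0, 0.0]
full_grad = local_gradient(start, data)
avg_grad = all_reduce_mean([local_gradient(start, s) for s in shards])
grads_match = all(abs(f - a) < 1e-9 for f, a in zip(full_grad, avg_grad))

# All replicas start equal and stay equal after several steps.
replicas = [list(start) for _ in shards]
for _ in range(3):
    replicas = data_parallel_step(replicas, shards, lr=0.01)
identical = all(r == replicas[0] for r in replicas)

print(grads_match, identical)  # True True
```

Note that only gradients cross the simulated device boundary. The weights are never exchanged, yet the replicas end every step with the same values.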

Why it scales

If one GPU handles a batch of 32, eight GPUs can handle 256 per step. With more samples per step, each epoch takes fewer steps, so a target number of epochs is reached in less wall-clock time. The limit is the communication cost of averaging gradients, which grows with model size and device count.
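The arithmetic can be made concrete with a short sketch. The dataset size and parameter count below are hypothetical, and the communication estimate assumes a ring all-reduce, one common way to average gradients, in which each of N devices sends roughly 2(N-1)/N times the gradient size per step. The lesson does not say which algorithm is used, so treat that formula as an illustrative assumption.

```python
import math


def steps_per_epoch(dataset_size, per_gpu_batch, num_gpus):
    """Steps needed to see the dataset once with a global batch of per_gpu_batch * num_gpus."""
    global_batch = per_gpu_batch * num_gpus
    return math.ceil(dataset_size / global_batch)


def ring_all_reduce_bytes_per_gpu(param_count, num_gpus, bytes_per_param=4):
    """Approximate bytes each GPU sends per step, assuming a ring all-reduce."""
    return 2 * (num_gpus - 1) / num_gpus * param_count * bytes_per_param


DATASET_SIZE = 51_200    # hypothetical number of training samples
PARAM_COUNT = 1_000_000  # hypothetical model size (32-bit gradients)

print(steps_per_epoch(DATASET_SIZE, per_gpu_batch=32, num_gpus=1))  # 1600
print(steps_per_epoch(DATASET_SIZE, per_gpu_batch=32, num_gpus=8))  # 200

for n in (1, 2, 8):
    print(n, ring_all_reduce_bytes_per_gpu(PARAM_COUNT, n))
```

Under these assumptions, eight GPUs cut the steps per epoch by a factor of eight, while the bytes each GPU must send per step scale linearly with the parameter count and rise with the number of devices.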

Key idea

Data parallelism copies the whole model to each device, splits the batch, and averages gradients so all replicas stay identical while training in parallel.

Check yourself

1. In data parallel training, what is replicated on every GPU?

2. What keeps the model replicas identical across steps?

3. What usually limits how far data parallelism scales?