Machine Learning

Fully Sharded Data Parallel

Shard parameters, gradients, and optimizer state to train models that would not otherwise fit.

The memory problem

Plain data parallelism replicates the full model, its gradients, and the optimizer state on every GPU, so adding GPUs does nothing to reduce per-device memory. For large models the optimizer state alone can be several times the size of the parameters, so each device runs out of memory long before compute becomes the bottleneck.
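To put rough numbers on this, here is a back-of-the-envelope sketch. The byte counts are assumptions for illustration (mixed-precision training with Adam: fp16 weights and gradients, plus fp32 master weights, momentum, and variance); they are not stated in this lesson, and activations are ignored. The function name and the 7B model size are hypothetical.

```python
# Rough per-GPU memory under plain data parallelism (activations ignored).
# Assumed byte counts for mixed-precision Adam; adjust for your own setup.
BYTES_PARAMS = 2      # fp16 weights
BYTES_GRADS = 2       # fp16 gradients
BYTES_OPTIMIZER = 12  # fp32 master weights + Adam momentum + Adam variance

def replicated_bytes_per_gpu(num_params: int) -> int:
    """Every GPU holds a full copy, so the GPU count does not appear here."""
    return num_params * (BYTES_PARAMS + BYTES_GRADS + BYTES_OPTIMIZER)

num_params = 7_000_000_000  # hypothetical 7B-parameter model
total = replicated_bytes_per_gpu(num_params)
print(f"{total / 1e9:.0f} GB per GPU")                          # 112 GB per GPU
print(f"optimizer / params = {BYTES_OPTIMIZER // BYTES_PARAMS}x")  # 6x
```

Under these assumptions the optimizer state is six times the size of the fp16 parameters, and the total exceeds the memory of a single 80 GB accelerator before any activations are stored.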

What FSDP does

Fully Sharded Data Parallel keeps only a shard of each tensor on each GPU. Parameters, gradients, and optimizer state are all partitioned across devices. No single GPU ever holds the whole model at rest.
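As a minimal sketch of what "a shard of each tensor" can mean, the snippet below flattens a tensor into a list, pads it so it divides evenly, and gives each rank one contiguous slice. The helper name `shard_flat` is hypothetical, and real implementations differ in detail; this only illustrates the partitioning idea.

```python
def shard_flat(values, world_size):
    """Split a flat list into world_size equal shards, zero-padding the tail."""
    shard_len = -(-len(values) // world_size)  # ceiling division
    padding = shard_len * world_size - len(values)
    padded = list(values) + [0.0] * padding
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]

# Five values across four ranks: each rank stores two slots, three are padding.
shards = shard_flat([1.0, 2.0, 3.0, 4.0, 5.0], world_size=4)
for rank, shard in enumerate(shards):
    print(rank, shard)
```

Each rank stores roughly 1/N of the tensor, where N is the number of GPUs; the same split applies to the matching gradients and optimizer state.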

The gather and release cycle

  • Before a layer runs its forward pass, FSDP all-gathers that layer's full parameters from all shards.
  • The layer computes, then the gathered copy is released to free memory.
  • The same gather happens in the backward pass to compute gradients, which are then reduce-scattered so each GPU keeps only its shard.
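The steps above can be simulated in plain Python with lists standing in for GPUs. This is a toy sketch, not a real collective-communication API: `all_gather` and `reduce_scatter` are hypothetical helpers, the "layer" just multiplies every weight by a scalar input and sums, and gradients are summed rather than averaged.

```python
def all_gather(shards):
    """Every rank receives the concatenation of all ranks' shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

def reduce_scatter(per_rank_grads):
    """Sum full gradients elementwise across ranks; rank r keeps only slice r."""
    world_size = len(per_rank_grads)
    shard_len = len(per_rank_grads[0]) // world_size
    summed = [sum(col) for col in zip(*per_rank_grads)]
    return [summed[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]

# Two ranks, one layer with four weights, stored as two weights per rank.
param_shards = [[1.0, 2.0], [3.0, 4.0]]
inputs = [2.0, 3.0]  # each rank processes a different piece of the batch

# Forward: gather the full weights, compute, release the gathered copy.
gathered = all_gather(param_shards)
outputs = [sum(w * x for w in gathered[r]) for r, x in enumerate(inputs)]
del gathered

# Backward: gather again, compute full local gradients, then reduce-scatter.
# In this toy layer d(output)/d(weight) is just x; a real layer would also
# use the gathered weights to propagate gradients to its inputs.
gathered = all_gather(param_shards)
local_grads = [[x for _ in gathered[r]] for r, x in enumerate(inputs)]
del gathered
grad_shards = reduce_scatter(local_grads)

print(outputs)      # [20.0, 30.0]
print(grad_shards)  # [[5.0, 5.0], [5.0, 5.0]]
```

After the reduce-scatter, each rank holds gradients only for the two weights it owns, which is exactly what its local optimizer shard needs for the update.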

This trades extra communication for a large memory saving. At any moment only the layer currently being computed is fully materialized, so peak memory stays low even for very large models.
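A rough estimate of that peak, as a sketch under stated assumptions: 16 bytes of training state per parameter (fp16 weights and gradients plus fp32 Adam state), a 2-byte gathered copy per parameter, one layer gathered at a time, and activations and any prefetching ignored. The function name, layer sizes, and GPU count are all hypothetical.

```python
BYTES_PER_PARAM_STATE = 16    # assumed: 2 weights + 2 grads + 12 optimizer state
BYTES_PER_GATHERED_PARAM = 2  # assumed: fp16 copy used for compute

def fsdp_peak_bytes(layer_param_counts, world_size):
    """Resting shard of all state plus one fully gathered layer (the largest)."""
    resting = sum(layer_param_counts) * BYTES_PER_PARAM_STATE / world_size
    gathered = max(layer_param_counts) * BYTES_PER_GATHERED_PARAM
    return resting + gathered

layers = [200_000_000] * 32  # hypothetical: 32 equal layers, 6.4B parameters
replicated = sum(layers) * BYTES_PER_PARAM_STATE
peak = fsdp_peak_bytes(layers, world_size=8)

print(f"replicated: {replicated / 1e9:.1f} GB per GPU")  # 102.4 GB per GPU
print(f"FSDP peak:  {peak / 1e9:.1f} GB per GPU")        # 13.2 GB per GPU
```

The gathered layer adds only a small term on top of the sharded state, which is why the peak stays close to the total divided by the number of GPUs.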

Key idea

FSDP shards parameters, gradients, and optimizer state across GPUs and materializes each layer only when needed, trading communication for the memory to train very large models.

Check yourself

1. What does FSDP shard across GPUs?

2. How are a layer's full parameters made available for computation?

3. What is the main tradeoff FSDP makes?