The memory problem
Plain data parallelism replicates the full model, gradients, and optimizer state on every GPU. With an optimizer such as Adam, the optimizer state alone can be several times the parameter size, so for large models each device runs out of memory long before compute becomes the limit.
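As a rough illustration of the arithmetic, the sketch below estimates per-GPU memory for model state under plain data parallelism. The byte counts assume mixed-precision training with Adam (16-bit parameters and gradients, plus 32-bit master weights, momentum, and variance). That setup is an assumption chosen for the example, not something fixed by the text above, and activations are ignored.

```python
# Assumed byte costs per parameter for mixed-precision Adam (illustrative only).
BYTES_PARAM = 2    # 16-bit working copy of the parameter
BYTES_GRAD = 2     # 16-bit gradient
BYTES_OPTIM = 12   # fp32 master weight (4) + momentum (4) + variance (4)


def replicated_bytes_per_gpu(num_params: int) -> int:
    """Model-state bytes on EACH GPU when everything is replicated."""
    return num_params * (BYTES_PARAM + BYTES_GRAD + BYTES_OPTIM)


num_params = 7_000_000_000  # hypothetical 7B-parameter model
total = replicated_bytes_per_gpu(num_params)
print(f"model state per GPU: {total / 1e9:.0f} GB")
print(f"optimizer state is {BYTES_OPTIM // BYTES_PARAM}x the parameter bytes")
```

Under these assumptions the model state comes to 16 bytes per parameter, and adding more GPUs does not reduce it, because every device holds the same full copy.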
What FSDP does
Fully Sharded Data Parallel keeps only a shard of each tensor on each GPU. Parameters, gradients, and optimizer state are all partitioned across devices. No single GPU ever holds the whole model at rest.
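To make the partitioning concrete, here is a minimal sketch of splitting one flattened tensor into equal shards, one per GPU. The helper name `shard_flat` and the zero-padding scheme are illustrative assumptions. Real implementations work on device tensors and handle this internally, but the idea of padding so the tensor divides evenly is similar.

```python
def shard_flat(values: list[float], world_size: int) -> list[list[float]]:
    """Split a flat tensor into world_size equal shards, zero-padding the tail."""
    shard_len = -(-len(values) // world_size)  # ceiling division
    padding = shard_len * world_size - len(values)
    padded = values + [0.0] * padding
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


# Five values across four hypothetical GPUs: each rank stores two elements.
shards = shard_flat([1.0, 2.0, 3.0, 4.0, 5.0], world_size=4)
for rank, shard in enumerate(shards):
    print(f"rank {rank} holds {shard}")
```

Each rank stores roughly 1/N of the elements, and the same split applies to the matching gradient and optimizer-state tensors.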
The gather and release cycle
- Before a layer runs forward, FSDP all-gathers that layer's full parameters from all shards.
- The layer computes, then the gathered copy is released to free memory.
- The same gather happens in backward to compute gradients, which are then reduce-scattered so each GPU keeps only its shard.
This trades extra communication for a large memory saving. At any moment only about one layer (plus the next one, if it is prefetched to overlap communication with compute) is fully materialized, so peak parameter memory stays low even for very large models.
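The cycle above can be simulated in a single process with plain Python lists standing in for GPUs. This is a toy sketch under stated assumptions, not how any real library is implemented: the functions `all_gather`, `reduce_scatter`, and `fsdp_step` are hypothetical helpers, the "layers" are independent terms with loss 0.5 * x * sum(w^2) rather than a chained network, and gradients are summed across ranks (real implementations often average instead).

```python
def all_gather(shards):
    """Every rank ends up with the concatenation of all shards."""
    full = [w for shard in shards for w in shard]
    return [list(full) for _ in shards]


def reduce_scatter(grads_per_rank):
    """Sum gradients elementwise across ranks; rank r keeps only slice r."""
    world = len(grads_per_rank)
    summed = [sum(column) for column in zip(*grads_per_rank)]
    n = len(summed) // world
    return [summed[r * n:(r + 1) * n] for r in range(world)]


def fsdp_step(layers, inputs):
    """Run a toy forward and backward pass, tracking peak materialized size."""
    peak = 0
    losses = [0.0] * len(inputs)
    for shards in layers:                      # forward pass
        full = all_gather(shards)              # materialize one layer
        peak = max(peak, len(full[0]))
        for r, x in enumerate(inputs):
            losses[r] += 0.5 * x * sum(w * w for w in full[r])
        del full                               # release the gathered copy
    grad_shards = []
    for shards in reversed(layers):            # backward pass
        full = all_gather(shards)              # gather again for gradients
        peak = max(peak, len(full[0]))
        local = [[x * w for w in full[r]] for r, x in enumerate(inputs)]
        del full
        grad_shards.append(reduce_scatter(local))
    grad_shards.reverse()                      # restore forward layer order
    return losses, grad_shards, peak


# Two ranks, two layers of 4 weights each; every rank stores 2 weights per layer.
layers = [[[1.0, 2.0], [3.0, 4.0]], [[0.5, 0.5], [0.5, 0.5]]]
inputs = [1.0, 10.0]  # each rank sees different data
losses, grad_shards, peak = fsdp_step(layers, inputs)
print("gradient shards for layer 0:", grad_shards[0])
print("peak fully materialized elements:", peak, "of", 8, "total")
```

In this toy run each rank never holds more than one layer's 4 weights in full, even though the model has 8 in total, and after the reduce-scatter each rank keeps only its own slice of the summed gradients.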
Key idea
FSDP shards parameters, gradients, and optimizer state across GPUs and materializes each layer only when needed, trading extra communication for the memory headroom needed to train very large models.