What is Fully Sharded Data Parallelism and When Should You Use It?
Fully Sharded Data Parallelism (FSDP) extends data parallelism by sharding not just the data but also the model parameters, gradients, and optimizer states across all workers. While standard data parallelism replicates the entire model on each device and only distributes the input data, FSDP partitions all of this state across workers, drastically reducing the per device memory footprint.
In standard data parallelism, each of N workers holds a full copy of a 140 GB model plus its gradients and optimizer states, totaling around 560 GB or more per worker when using Adam with its two moment vectors (parameters, gradients, and two moments at 140 GB each). With FSDP, that same state is sharded across the N workers, bringing per worker memory down to roughly the total divided by N, plus activation and communication buffer overhead. For a 405 billion parameter model whose unsharded training state occupies about 7.3 TB, FSDP can bring the per device footprint down to around 1.3 TB when sharded across workers, roughly an 80 percent reduction that enables training on more affordable hardware or a higher degree of parallelism.
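To make the arithmetic concrete, here is a rough back of the envelope sketch (an illustration, not from the source): it assumes parameters, gradients, and both Adam moment vectors are stored in the same precision, and it ignores activation and communication buffer overhead.

# Back of the envelope training state memory, per worker.
# Assumption: params, grads, and both Adam moments share one precision;
# activations and communication buffers are ignored.
def per_worker_state_gb(param_gb: float, n_workers: int, sharded: bool) -> float:
    total_state = 4 * param_gb  # params + grads + exp_avg + exp_avg_sq
    return total_state / n_workers if sharded else total_state

print(per_worker_state_gb(140, 8, sharded=False))  # 560.0 GB replicated on every worker
print(per_worker_state_gb(140, 8, sharded=True))   # 70.0 GB per worker under FSDP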
The trade off is more communication. In the forward pass, each worker holds only a shard of the parameters, so it must gather the full parameter tensor from its peers before computing its local activations. In the backward pass, gradients are computed locally and then sharded across workers via reduce scatter operations. Optimizer updates happen on local shards, so each worker updates only its portion of the model. This pattern reduces memory dramatically but increases communication volume compared to the single all reduce of gradients in standard data parallelism.
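The per step cycle can be written out schematically as below. This is an illustrative sketch of the gather, compute, reduce scatter pattern, not PyTorch FSDP's actual internals; it assumes torch.distributed is already initialized, that each worker holds a flat 1-D shard of every layer's parameters, and that forward_backward and local_optimizer_step are placeholder callables.

import torch
import torch.distributed as dist

def fsdp_like_step(flat_param_shards, forward_backward, local_optimizer_step):
    world = dist.get_world_size()

    # Forward: all gather each layer's full parameter tensor from the per worker shards
    full_params = []
    for shard in flat_param_shards:
        full = torch.empty(shard.numel() * world, dtype=shard.dtype, device=shard.device)
        dist.all_gather_into_tensor(full, shard)
        full_params.append(full)

    # Compute this worker's activations and full (unsharded) gradients on its local batch
    full_grads = forward_backward(full_params)

    # Backward: reduce scatter so each worker keeps only the summed gradient for its shard
    grad_shards = []
    for shard, full_grad in zip(flat_param_shards, full_grads):
        grad_shard = torch.empty_like(shard)
        dist.reduce_scatter_tensor(grad_shard, full_grad)
        grad_shards.append(grad_shard)

    # Optimizer: each worker updates only its local parameter shards
    local_optimizer_step(flat_param_shards, grad_shards)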
FSDP is most valuable when optimizer and gradient memory dominate your footprint, which happens with very large models or when you want to maximize data parallel degree without replicating massive parameter sets. It is less beneficial if your model already fits comfortably on each device, because the added communication overhead outweighs memory savings. Google's use of FSDP style sharding in their large language model training allows scaling to higher data parallelism degrees, and PyTorch FSDP adoption has enabled researchers to train 70 billion parameter models on clusters of 80 GB A100s without prohibitive memory costs.
When implementing FSDP, choose a sharding strategy that balances memory savings against communication cost. Fully shard everything when memory is tight and network bandwidth is strong. Use hybrid sharding within nodes if intra node links are fast but inter node links are slow. Monitor gradient reduce scatter and parameter all gather latencies to ensure they overlap with compute and do not stall the training step.
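As a minimal PyTorch FSDP wrapping sketch under stated assumptions: MyTransformer, the wrap threshold, and the learning rate are placeholders, and the sharding_strategy argument is where full versus hybrid sharding is selected.

import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = MyTransformer().cuda()  # hypothetical model class, stands in for your network

sharded_model = FSDP(
    model,
    # FULL_SHARD: shard params, grads, and optimizer state across all ranks.
    # HYBRID_SHARD: shard within each node and replicate across nodes, keeping
    # the heavy all gather traffic on fast intra node links.
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000),
    device_id=torch.cuda.current_device(),
)

optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=1e-4)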
💡 Key Takeaways
• FSDP shards parameters, gradients, and optimizer states across workers, reducing per device memory by up to 80 percent compared to standard data parallelism, where each worker holds a full copy
• For a 405 billion parameter model, FSDP cuts the per device footprint from 7.3 TB unsharded to roughly 1.3 TB once the state is sharded across workers, enabling training on more affordable hardware
• The trade off is increased communication: each forward pass requires an all gather of parameter shards and each backward pass a reduce scatter of gradient shards, versus a single all reduce of gradients in standard data parallelism
• Most valuable when optimizer and gradient memory dominate, as with large language models trained with Adam, which stores two moment vectors per parameter in higher precision
• Hybrid sharding strategies shard within nodes but replicate across nodes, keeping heavy all gather traffic on fast intra node NVLink while minimizing slow inter node traffic
📌 Examples
PyTorch FSDP allows training 70 billion parameter models on clusters of 80 GB A100 GPUs by reducing per GPU memory from over 400 GB to under 80 GB
Google uses FSDP style sharding to scale data parallelism degree from 4 to 32 without proportionally increasing total cluster memory requirements
A 140 GB model with standard DP on 8 workers uses 1120 GB total, FSDP on 8 workers uses 140 GB total plus temporary all gather buffers, saving over 900 GB
Training step timeline with FSDP: all gather parameters in 50 ms, forward compute 200 ms, backward compute 200 ms, reduce scatter gradients in 50 ms, with communication overlapped with compute to hide roughly 80 percent of the communication latency
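A sketch of the FSDP arguments that control this overlap (argument values are illustrative, and model is assumed to be defined as in the wrapping sketch above): backward prefetch issues the next all gather while the current gradients are still being computed, which is what hides most of the 50 ms communication phases behind the 200 ms compute phases.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, BackwardPrefetch

sharded_model = FSDP(
    model,  # assumed to be defined as in the wrapping sketch above
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch next params during the current backward
    forward_prefetch=True,   # issue the next all gather during the current forward compute
    limit_all_gathers=True,  # throttle all gathers so gathered buffers do not accumulate
)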