Training Infrastructure & Pipelines › Model Checkpointing & Recovery · Hard · ⏱️ ~3 min

World Size Agnostic Checkpoints and Elastic Recovery

The Brittleness Problem

Traditional checkpoints embed the number of GPUs (world size) and the parallelism strategy directly into the checkpoint format by saving tensors under rank-specific filenames or shapes. This creates a brittle coupling: if you trained on 256 GPUs with 8-way tensor parallelism and need to resume on 128 GPUs with 4-way parallelism, you face a costly offline reshape and repartition of the entire checkpoint.

World Size Agnostic Design

World-size-agnostic checkpoints solve this by storing logical parameter names and partition metadata instead of rank-indexed state, enabling elastic recovery: training can resume on a different number of GPUs without rewriting checkpoints. Each parameter tensor is saved under a global logical name together with metadata describing how it was partitioned.
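A minimal sketch of what one such checkpoint record might look like. The field names and the `make_shard_entry` helper are illustrative assumptions, not a real library API; the point is that the record is keyed by a logical name plus partition metadata rather than a rank ID:

```python
# Hypothetical world-size-agnostic checkpoint entry. Instead of a
# rank-indexed file like "rank_3.pt", each shard is keyed by a logical
# parameter name plus metadata describing how it was partitioned.

def make_shard_entry(logical_name, shard_index, num_shards, axis,
                     full_shape, data):
    """Build one checkpoint record; names and fields are illustrative."""
    return {
        "name": logical_name,        # global name, independent of rank
        "shard_index": shard_index,  # which slice of the full tensor
        "num_shards": num_shards,    # partition degree at save time
        "axis": axis,                # dimension the tensor was split along
        "full_shape": full_shape,    # unpartitioned shape, for validation
        "data": data,                # this shard's values
    }

# Rank 3 of an 8-way tensor-parallel group saves its slice of a
# 4096 x 4096 weight matrix split along axis 0 (4096 / 8 = 512 rows):
entry = make_shard_entry(
    "transformer.layer0.attention.weight",
    shard_index=3, num_shards=8, axis=0,
    full_shape=(4096, 4096),
    data=[[0.0] * 4096 for _ in range(512)],
)
assert len(entry["data"]) == entry["full_shape"][0] // entry["num_shards"]
```

Because the record carries `full_shape` and the partition degree, a loader can reconstruct the logical tensor without knowing how many ranks wrote the checkpoint.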

Resharding on Restore

On restore, a resharding algorithm reads these logical tensors and repartitions them according to the new world size and parallelism configuration. For example, a weight matrix split 8 ways by tensor parallelism across 256 GPUs is stored as 8 logical shards. On resume with 4-way tensor parallelism on 128 GPUs, the loader merges pairs of adjacent shards into 4 chunks.
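The merge-and-redistribute step can be sketched in a few lines. This is an illustrative toy (shards are plain Python lists split along one axis), not a real framework's resharding routine:

```python
# Minimal resharding sketch: concatenate the old shards in logical
# order to recover the full tensor along the split axis, then re-split
# for the new partition degree.

def reshard(shards, new_num_shards):
    """Merge shards saved at one degree, re-split at another."""
    full = [row for shard in shards for row in shard]  # logical full tensor
    assert len(full) % new_num_shards == 0, "split axis must divide evenly"
    chunk = len(full) // new_num_shards
    return [full[i * chunk:(i + 1) * chunk] for i in range(new_num_shards)]

# 8-way split at save time -> 4-way split on restore:
old_shards = [[f"row{i}" for i in range(s * 2, s * 2 + 2)] for s in range(8)]
new_shards = reshard(old_shards, 4)
assert len(new_shards) == 4
# Each new shard is the concatenation of two adjacent old shards:
assert new_shards[0] == old_shards[0] + old_shards[1]
```

Real implementations do the same logical operation on tensor slices, and must also handle uneven splits and multi-dimensional device meshes.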

Implementation Examples

Meta's FSDP (in PyTorch) and NVIDIA's Megatron both implement this pattern. FSDP checkpoints include a parameter-flattening map that records original tensor shapes and shard boundaries independently of rank count, so the restore path can reshard checkpoints across different parallelism configurations. Elastic recovery is not free: it adds 5 to 10 minutes to the recovery time objective (RTO) for large models, and optimizer state resharding requires careful mapping to preserve numerical consistency.
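One common way to validate numerical consistency after an elastic restore is to compare a scalar statistic, such as a gradient norm, against a reference run within a small relative tolerance. The helper below is a sketch; the 0.1% bar matches the validation criterion described in the takeaways, and the function name is an assumption:

```python
# Illustrative post-restore sanity check: compare a gradient norm
# recorded before resharding against the value measured after resume.

def within_tolerance(reference, restored, rel_tol=1e-3):
    """True if |restored - reference| / |reference| <= rel_tol (0.1% default)."""
    if reference == 0:
        return restored == 0
    return abs(restored - reference) / abs(reference) <= rel_tol

assert within_tolerance(12.500, 12.506)      # ~0.05% drift: acceptable
assert not within_tolerance(12.500, 12.700)  # ~1.6% drift: investigate
```

A check like this catches the most common resharding bug, Adam momentum or variance tensors misaligned with their repartitioned parameters, before wasting further compute.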

💡 Key Takeaways
- Logical parameter naming decouples checkpoint format from world size: each tensor is stored under a global name (e.g., transformer.layer.attention.weight) with partition metadata, not rank IDs
- Resharding on restore: trained on 256 GPUs with 8-way tensor parallelism, resume on 128 GPUs with 4-way by merging and redistributing shards; adds 5 to 10 minutes to Recovery Time Objective (RTO) for TB-scale checkpoints
- Optimizer state resharding is complex: Adam momentum and variance must align with the new parameter partitions to avoid divergence; a validation step checks that gradient norms match within 0.1% after elastic restore
- Meta FSDP and NVIDIA Megatron/NeMo both support elastic checkpoints; Google JAX uses explicit partition specs in checkpoint metadata to enable restore on different device mesh topologies
- Trade-off: world-size-agnostic checkpoints add restore complexity and time but enable cost optimization (scale down during idle periods) and fault tolerance (recover on a partial cluster after hardware loss)
📌 Interview Tips
1. NVIDIA Megatron 530B model: checkpointed with 8-way tensor parallelism and 16-way pipeline parallelism on 1024 GPUs; restored to 4-way tensor and 8-way pipeline on 512 GPUs after a cluster resize, adding 12 minutes to RTO
2. Meta OPT 175B: FSDP checkpoint saved on 256 A100s, restored to 128 A100s for cost reduction during the weekend; resharding took 8 minutes, and the training loss curve matched the reference within 0.08% after resume
3. Google T5 XXL on TPU v4: checkpoint from 512 TPU cores with a 2D mesh [16, 32], restored to 256 cores with mesh [8, 32]; partition spec remapping completed in 6 minutes with automatic shard redistribution