World-Size-Agnostic Checkpoints and Elastic Recovery
Traditional checkpoints embed the number of GPUs (world size) and parallelism strategy directly into the checkpoint format by saving tensors under rank-specific filenames or shapes. This creates a brittle coupling: if you trained on 256 GPUs with 8-way tensor parallelism and need to resume on 128 GPUs with 4-way parallelism (perhaps due to hardware availability or cost), you face a costly offline reshape and repartition of the entire checkpoint. World-size-agnostic checkpoints solve this by storing logical parameter names and partition metadata instead of rank-indexed state, enabling elastic recovery: you can resume training on a different number of GPUs, or change parallelism degrees, without rewriting checkpoints.
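As a concrete illustration of that decoupling, here is a minimal sketch of what a logical checkpoint index could look like. The field names (global_shape, split_dim, shard_files, and the helper save_logical_shard) are assumptions made for this example, not any framework's actual format; the point is that nothing in the layout refers to a rank or a GPU count.

```python
import json

import torch

# Hypothetical index: logical parameter names map to partition metadata and
# shard files. Nothing here encodes a rank ID or the number of GPUs.
checkpoint_index = {
    "transformer.layers.0.attention.query.weight": {
        "global_shape": [12288, 12288],
        "split_dim": 1,                 # dimension of the tensor-parallel split
        "num_shards": 8,                # logical shards written at save time
        "shard_files": [f"layers.0.attn.q.weight.shard{i}.pt" for i in range(8)],
    },
}

def save_logical_shard(tensor: torch.Tensor, logical_name: str, shard_idx: int) -> None:
    """Write one shard under its logical name; which rank produced it is irrelevant."""
    meta = checkpoint_index[logical_name]
    torch.save(tensor, meta["shard_files"][shard_idx])

with open("checkpoint_index.json", "w") as f:
    json.dump(checkpoint_index, f, indent=2)
```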
The key insight is to decouple the logical model structure from its physical distribution across devices. Each parameter tensor is saved under a global logical name (e.g., "transformer.layers.0.attention.query.weight") along with metadata describing how it was partitioned (tensor-parallel split along dimension 0, pipeline-parallel stage 2). On restore, a resharding algorithm reads these logical tensors and repartitions them according to the new world size and parallelism configuration. For example, a [12288, 12288] weight matrix split 8 ways by tensor parallelism across 256 GPUs is stored as 8 logical shards of shape [12288, 1536]. On resume with 4-way tensor parallelism on 128 GPUs, the loader merges adjacent pairs of shards into 4 chunks of shape [12288, 3072].
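A minimal sketch of that merge-and-resplit step, going from 8-way to 4-way tensor parallelism. The helper reshard_tensor_parallel is hypothetical, and the demo shapes are scaled down from the [12288, 12288] example above so the snippet runs in a few megabytes of memory.

```python
import torch

def reshard_tensor_parallel(shards, new_degree, split_dim):
    """Merge logical shards saved at one tensor-parallel degree and re-split
    them for a new degree (the old degree must be a multiple of the new one)."""
    assert len(shards) % new_degree == 0
    full = torch.cat(shards, dim=split_dim)            # reconstruct the logical tensor
    return list(torch.chunk(full, new_degree, dim=split_dim))

# Stand-in for a [12288, 12288] weight split 8 ways along dim 1 (real shards: [12288, 1536]);
# scaled here to a [1024, 1024] logical tensor with shards of [1024, 128].
old_shards = [torch.randn(1024, 128) for _ in range(8)]
new_shards = reshard_tensor_parallel(old_shards, new_degree=4, split_dim=1)
assert len(new_shards) == 4 and new_shards[0].shape == (1024, 256)  # adjacent pairs merged
```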
Meta's Fully Sharded Data Parallel (FSDP) and NVIDIA's Megatron both implement this pattern. FSDP checkpoints include a parameter-flattening map that records original tensor shapes and shard boundaries independent of rank count. Megatron saves pipeline- and tensor-parallel metadata alongside each parameter, and the restore path in NeMo can reshape checkpoints across different parallelism configurations by gathering and redistributing slices. Google's JAX-based T5X checkpoints use a similar logical naming scheme with explicit partition specs that the restore code interprets to rebalance state across a new device mesh.
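For reference, the pattern is visible in PyTorch's torch.distributed.checkpoint, the machinery behind FSDP's sharded checkpoints: saving writes shards keyed by logical parameter names plus shard metadata, and loading replans them against the new job's layout. Below is a rough single-process sketch; no_dist=True and the plain nn.Linear stand in for a real process group and an FSDP-wrapped model, and exact entry points have shifted across PyTorch releases, so treat the calls as approximate.

```python
import torch.nn as nn
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter

# A plain module stands in for an FSDP-wrapped model; in a real job each rank
# would pass its sharded state dict and drop no_dist.
model = nn.Linear(16, 16)
dcp.save_state_dict({"model": model.state_dict()},
                    storage_writer=FileSystemWriter("/tmp/ckpt_demo"),
                    no_dist=True)

# On restore, the load planner matches logical keys and shard metadata against
# the new job's (possibly differently sharded) state dict and fills it in place.
restored = nn.Linear(16, 16)
target = {"model": restored.state_dict()}
dcp.load_state_dict(target,
                    storage_reader=FileSystemReader("/tmp/ckpt_demo"),
                    no_dist=True)
restored.load_state_dict(target["model"])
```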
Elastic recovery is not free. Resharding a multi-TB checkpoint on load adds minutes to the Recovery Time Objective (RTO). For a 175B-parameter model (a 1.4 TB checkpoint), changing from 256 to 128 GPUs might add 5 to 10 minutes to the restore as ranks exchange shard slices over the network. Optimizer-state resharding is also more complex than parameter resharding, because Adam's per-parameter momentum and variance must align with the new parameter partitions; an incorrect mapping can cause training instability or divergence. Teams validate elastic restores by comparing loss curves and gradients after resume to confirm numerical consistency within an acceptable tolerance (typically within 0.1% relative error in gradient norms).
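The two concerns in this paragraph, keeping Adam's moments aligned with the new parameter partitions and sanity-checking the restore, could look roughly like the sketch below. The helper names are made up for illustration; only the 0.1% tolerance comes from the text.

```python
import torch

def reshard_with_adam_state(param_shards, exp_avg_shards, exp_avg_sq_shards,
                            new_degree, split_dim):
    """Reshard a parameter together with its Adam moments so that each new shard
    of exp_avg / exp_avg_sq lines up element-for-element with its parameter shard."""
    def merge_split(shards):
        full = torch.cat(shards, dim=split_dim)
        return list(torch.chunk(full, new_degree, dim=split_dim))
    return (merge_split(param_shards),
            merge_split(exp_avg_shards),
            merge_split(exp_avg_sq_shards))

def gradient_norms_consistent(norm_before, norm_after, rel_tol=1e-3):
    """Post-resume check: gradient norms before and after the elastic restore
    should agree within ~0.1% relative error."""
    return abs(norm_after - norm_before) / max(abs(norm_before), 1e-12) <= rel_tol
```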
💡 Key Takeaways
•Logical parameter naming decouples the checkpoint format from world size: each tensor is stored with a global name (e.g., transformer.layer.attention.weight) and partition metadata, not rank IDs
•Resharding on restore: a run trained on 256 GPUs with 8-way tensor parallelism can resume on 128 GPUs with 4-way by merging and redistributing shards; this adds 5 to 10 minutes to RTO for TB-scale checkpoints
•Optimizer-state resharding is harder: Adam momentum and variance must align with the new parameter partitions to avoid divergence; a validation step checks that gradient norms match within 0.1% after an elastic restore
•Meta FSDP and NVIDIA Megatron/NeMo both support elastic checkpoints; Google's JAX stack uses explicit partition specs in checkpoint metadata to enable restore on different device-mesh topologies
•Trade-off: world-size-agnostic checkpoints add restore complexity and time, but enable cost optimization (scale down during idle periods) and fault tolerance (recover on a partial cluster after hardware loss)
📌 Examples
NVIDIA Megatron 530B model: checkpointed with 8-way tensor parallelism and 16-way pipeline parallelism on 1024 GPUs; restored with 4-way tensor and 8-way pipeline parallelism on 512 GPUs after a cluster resize, adding 12 minutes to RTO
Meta OPT 175B: FSDP checkpoint saved on 256 A100s, restored on 128 A100s for cost reduction over a weekend; resharding took 8 minutes, and the training loss curve matched the reference within 0.08% after resume
Google T5 XXL on TPU v4: checkpoint from 512 TPU cores with a 2D mesh of [16, 32], restored on 256 cores with a mesh of [8, 32]; partition-spec remapping completed in 6 minutes with automatic shard redistribution