
Loss Balancing and Gradient Interference

Loss balancing determines how much each task influences the shared backbone during training. The naive approach sums task losses with equal weights, but this fails when loss magnitudes differ by orders of magnitude. A CTR task trained with binary cross-entropy might produce loss values around 0.3 to 0.5, while a regression task predicting dwell time in seconds might have a mean squared error (MSE) around 100 to 500. The dwell-time task then dominates gradient updates, and the CTR task effectively stops learning.

Static normalization fixes the initial scales by dividing each task loss by its first-epoch average, equalizing magnitudes; this is a good starting point. Dynamic methods adapt during training: uncertainty weighting scales each task by learned noise parameters that represent task confidence, and GradNorm adjusts weights to equalize gradient norms across tasks, preventing any single task from dominating. Conflict-aware methods such as PCGrad (projecting conflicting gradients) detect when task gradients point in opposite directions and project out the conflicting components. These techniques reduced training instability in Meta's multi-objective ranking models. Minimal sketches of these methods follow below.

Gradient interference happens when tasks push shared parameters in opposite directions. Imagine a shared layer learning a user embedding: the CTR task wants to increase a weight to capture clickbait signals, while the safety task wants to decrease it to penalize sensational content. Each gradient update partially cancels the other. The result is slow convergence or, worse, oscillation where neither task improves. At scale, this shows up as flat or declining validation metrics for minority tasks even as the dominant task improves.

Production solutions combine several approaches. Start with static loss normalization. Add per-task learning-rate modulation based on validation performance, increasing rates for stalled tasks. Introduce GradNorm or similar dynamic balancing after the first few epochs. Monitor per-task gradient norms and the cosine similarity between task gradients to detect interference early. For tasks with fundamentally conflicting objectives, consider separate models composed in a downstream policy layer instead of forcing shared optimization.
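A minimal sketch of static loss normalization, assuming each task produces a scalar loss tensor per step and using a hypothetical fixed calibration window to stand in for "the first epoch":

```python
import torch

class StaticLossNormalizer:
    """Divides each task loss by its average over a calibration window
    (standing in for the first-epoch average), equalizing loss scales."""

    def __init__(self, task_names, calibration_steps=1000):
        self.calibration_steps = calibration_steps
        self.step = 0
        self.sums = {t: 0.0 for t in task_names}
        self.scales = {t: 1.0 for t in task_names}  # identity until calibrated

    def __call__(self, losses):
        # losses: dict mapping task name -> scalar loss tensor
        if self.step < self.calibration_steps:
            for t, loss in losses.items():
                self.sums[t] += loss.detach().item()
            self.step += 1
            if self.step == self.calibration_steps:
                self.scales = {t: s / self.calibration_steps
                               for t, s in self.sums.items()}
        # After calibration, each task contributes at roughly unit scale.
        return sum(loss / max(self.scales[t], 1e-8)
                   for t, loss in losses.items())

normalizer = StaticLossNormalizer(["ctr", "dwell_time"])
total_loss = normalizer({"ctr": torch.tensor(0.4),
                         "dwell_time": torch.tensor(250.0)})
```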
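Uncertainty weighting can be sketched as a small module holding one learned log-variance per task, in the form popularized by Kendall et al.; the exact term differs between classification and regression heads, so treat this as a simplified common variant:

```python
import torch
import torch.nn as nn

class UncertaintyWeighting(nn.Module):
    """One learned log-variance per task. exp(-log_var) down-weights noisy
    tasks; the +log_var term keeps variances from growing without bound."""

    def __init__(self, num_tasks):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, task_losses):
        total = torch.zeros(())
        for i, loss in enumerate(task_losses):
            total = total + torch.exp(-self.log_vars[i]) * loss + self.log_vars[i]
        return total

weighter = UncertaintyWeighting(num_tasks=2)
# The log-variances train jointly with the model parameters, e.g.:
# optimizer = torch.optim.Adam([*model.parameters(), *weighter.parameters()])
```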
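Full GradNorm learns the task weights with its own objective that also balances relative training rates; the simpler norm-equalizing variant below captures the core idea by setting each weight inversely proportional to that task's gradient norm on the shared parameters (the `backbone` name in the usage comment is illustrative):

```python
import torch

def equalize_gradient_norms(task_losses, shared_params, eps=1e-8):
    """Returns weights w_i proportional to 1 / ||grad of L_i w.r.t. shared
    params||, so each weighted task contributes a comparable gradient
    magnitude to the shared backbone."""
    norms = []
    for loss in task_losses:
        grads = torch.autograd.grad(loss, shared_params, retain_graph=True)
        norms.append(torch.sqrt(sum((g ** 2).sum() for g in grads)))
    norms = torch.stack(norms)
    weights = 1.0 / (norms + eps)
    weights = weights / weights.sum() * len(task_losses)  # mean weight of 1
    return weights.detach()

# Usage inside a training step, after computing per-task losses:
# weights = equalize_gradient_norms([loss_ctr, loss_dwell],
#                                   list(backbone.parameters()))
# total_loss = sum(w * l for w, l in zip(weights, [loss_ctr, loss_dwell]))
# total_loss.backward()
```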
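Conflict-aware projection in the style of PCGrad, plus the cosine-similarity monitoring signal, fits in a few lines; the flattened toy gradients below are illustrative, not from any production model:

```python
import torch

def cosine_similarity(g_a, g_b):
    """Monitoring signal: persistently negative values indicate interference."""
    return torch.dot(g_a, g_b) / (g_a.norm() * g_b.norm() + 1e-12)

def project_out_conflict(g_a, g_b):
    """PCGrad-style step: if g_a conflicts with g_b (negative cosine
    similarity), remove g_a's component along g_b; otherwise leave it."""
    dot = torch.dot(g_a, g_b)
    if dot < 0:
        g_a = g_a - (dot / (g_b.norm() ** 2 + 1e-12)) * g_b
    return g_a

# Toy flattened task gradients pointing in roughly opposite directions:
g_task_a = torch.tensor([1.0, 0.5, -0.2])
g_task_b = torch.tensor([-0.8, 0.6, 0.3])
print(cosine_similarity(g_task_a, g_task_b))     # about -0.47: conflict
print(project_out_conflict(g_task_a, g_task_b))  # component along g_b removed
```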
💡 Key Takeaways
Naive equal weighting fails when task losses differ by 10 to 100 times in magnitude, causing one task to dominate all gradient updates
Static normalization divides each loss by its initial average to equalize scales, providing a strong baseline before dynamic methods
GradNorm dynamically adjusts loss weights to maintain equal gradient norms across tasks, preventing domination and improving convergence speed
Gradient interference occurs when task gradients have negative cosine similarity, causing updates to cancel and stalling learning
Production monitoring must track per-task gradient norms, loss curves, and gradient cosine similarity to detect and diagnose interference
📌 Examples
Meta ad ranking: GradNorm improved minority task AUC by 0.3% by preventing CTR task (80% of labels) from dominating conversion task (2% of labels)
Gradient conflict example: CTR gradient [0.5, 0.3, 0.2] and safety gradient [0.1, 0.4, -0.6] conflict on the third parameter; the per-component products 0.05, 0.12, and -0.12 largely cancel, pulling the overall cosine similarity down to roughly 0.11
Uber demand prediction: Uncertainty weighting learned to downweight noisy cancellation task (30% label noise) relative to clean trip completion task