What is Multi-Task Learning?
Multi-task learning trains a single neural network to solve several related tasks simultaneously instead of building separate models for each task. The architecture consists of a shared backbone (also called a trunk) that learns common feature representations, topped with small task-specific heads that map those shared features to individual task outputs. For example, Google and Meta use multi-task models in ad ranking, where one shared encoder predicts click-through rate (CTR), conversion rate, dwell time, and safety score through four separate heads.
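A minimal sketch of this shape in PyTorch (dimensions, layer counts, and head names are illustrative assumptions, not any specific production architecture):

```python
import torch
import torch.nn as nn

class MultiTaskRanker(nn.Module):
    """One shared trunk plus a lightweight head per task."""

    def __init__(self, input_dim: int = 256, hidden_dim: int = 512):
        super().__init__()
        # Shared backbone (trunk): learns the common feature representation
        self.trunk = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Small task-specific heads map shared features to each objective
        self.heads = nn.ModuleDict({
            "ctr": nn.Linear(hidden_dim, 1),
            "cvr": nn.Linear(hidden_dim, 1),
            "dwell_time": nn.Linear(hidden_dim, 1),
            "safety": nn.Linear(hidden_dim, 1),
        })

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        shared = self.trunk(x)  # one pass through the expensive encoder
        return {name: head(shared) for name, head in self.heads.items()}
```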
The core benefit is positive transfer across tasks. When predicting both CTR and conversion rate, the shared layers learn user preference representations that improve both objectives, especially when conversion labels are sparse (maybe 1% of impressions) but click labels are abundant (10 to 20% of impressions). This acts as implicit regularization, reducing overfitting on any single task and improving sample efficiency.
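Training typically combines the objectives as a weighted sum of per-task losses, with masks so that sparse tasks like conversion only contribute loss where labels exist. A sketch assuming the head names above; the weights and loss choices are illustrative, not tuned values:

```python
import torch
import torch.nn.functional as F

# Per-task losses and weights (assumptions for illustration)
LOSS_FNS = {
    "ctr": F.binary_cross_entropy_with_logits,
    "cvr": F.binary_cross_entropy_with_logits,
    "dwell_time": F.mse_loss,  # dwell time is a regression target
    "safety": F.binary_cross_entropy_with_logits,
}
TASK_WEIGHTS = {"ctr": 1.0, "cvr": 2.0, "dwell_time": 0.5, "safety": 1.0}

def multi_task_loss(preds, labels, masks):
    """Weighted sum of per-task losses; masks skip examples with no label
    for a task (e.g. CVR labels on only ~1% of impressions)."""
    total = torch.zeros((), device=preds["ctr"].device)
    for task, loss_fn in LOSS_FNS.items():
        mask = masks[task].float()  # 1 where this task has a label
        if mask.sum() == 0:
            continue  # no labeled examples for this task in the batch
        per_example = loss_fn(preds[task], labels[task], reduction="none")
        total = total + TASK_WEIGHTS[task] * (per_example * mask).sum() / mask.sum()
    return total
```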
In production, multi-task learning solves a practical serving problem. Running four separate ranking models might cost 40 to 60 milliseconds total (four models at roughly 10 to 15 milliseconds each). A multi-task architecture runs one 12 millisecond shared encoder plus four 1 millisecond heads for 16 milliseconds total. At 100,000 requests per second, this difference translates to hundreds fewer servers and tighter Service Level Objective (SLO) compliance.
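The back-of-envelope arithmetic, using the illustrative numbers above:

```python
ENCODER_MS = 12.0  # shared encoder latency (illustrative)
HEAD_MS = 1.0      # per-head latency (illustrative)
NUM_TASKS = 4

separate_ms = NUM_TASKS * ENCODER_MS              # 48 ms: full encoder per task
multi_task_ms = ENCODER_MS + NUM_TASKS * HEAD_MS  # 16 ms: encoder runs once

print(f"separate: {separate_ms:.0f} ms, multi-task: {multi_task_ms:.0f} ms")
print(f"speedup: {separate_ms / multi_task_ms:.1f}x")  # 3.0x
```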
The trade-off is optimization complexity. Multiple loss functions must be balanced carefully, and gradients from different tasks can conflict. A safety head might push the model away from sensational content while a CTR head pulls toward it. Without proper gradient management, you risk negative transfer, where tasks hurt each other instead of helping.
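One widely used remedy is gradient surgery, e.g. PCGrad (Yu et al., 2020): when two task gradients conflict (negative dot product), project the conflicting component of one off the other. A simplified sketch over flattened per-task gradient vectors (the published method also randomizes the order in which other tasks are visited):

```python
import torch

def pcgrad_combine(grads: list[torch.Tensor]) -> torch.Tensor:
    """Project each task gradient off any conflicting task gradient,
    then sum the projected gradients into one update direction."""
    combined = torch.zeros_like(grads[0])
    for i, g in enumerate(grads):
        g = g.clone()
        for j, other in enumerate(grads):
            if i == j:
                continue
            dot = torch.dot(g, other)
            if dot < 0:  # conflict: gradients point in opposing directions
                g = g - (dot / other.norm().pow(2)) * other
        combined = combined + g
    return combined
```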
💡 Key Takeaways
•Single shared backbone learns common features, with lightweight task-specific heads on top for each objective
•Reduces serving cost by 2 to 3 times compared to separate models, critical at high queries-per-second (QPS) scale
•Improves sample efficiency when some tasks have sparse labels; dense tasks like CTR help sparse tasks like conversion
•Requires careful loss balancing and gradient management to prevent negative transfer where tasks hurt each other
•Used in production by Meta and Google for ad ranking with multiple objectives: CTR, conversion rate (CVR), dwell time, safety scores
📌 Examples
Meta ad ranking: One shared encoder predicts CTR (10% base rate), CVR (1% base rate), dwell time, and safety through four heads, reducing p99 latency from 60ms to 16ms
Tesla autonomous driving: Single vision backbone outputs object detection, depth estimation, lane segmentation, and motion vectors from camera frames within 30ms total
Google keyboard: Shared encoder predicts next word, punctuation, and capitalization within a 5 to 20ms per-keystroke budget for real-time typing