Interactive Demo

Weak-Driven Learning

How Weak Agents Make Strong Agents Stronger

Based on Chen et al. (2026) · arXiv:2602.08222

Prepared for topic sharing by Goh Jun Xian, Koo Haoming, Low Jia Li Natalie, Joseph Poon, Wong Fang Ting & Yap Zheng Mou
All authors contributed equally. Names listed in alphabetical order.

Post-Training · Logit Mixing · Gradient Amplification
Scroll to explore the full pipeline

You've probably heard that bigger models and more data make AI better. But what if the training itself is the bottleneck?

This post explains a simple but powerful idea from recent research: a model's own weaker past self can make it smarter - no bigger model, no extra data, no added cost at inference. The technique is called WMSS (Weak-Driven Learning), and it improves math and code benchmarks by about 6% over standard training.

Before you scroll, ask yourself:

1

Why does training stall even when a model hasn't fully learned the material? (Hint: it's not overfitting.)

2

If a student aces practice exams by memorising answers, will they pass the real test? What's missing from their learning?

3

Can a weaker model teach a stronger one something useful? Sounds backwards - but it works. This post shows you exactly how and why.

5-minute read · interactive demos · no ML background required
↓ Scroll to start
Background

How Are LLMs Trained?

Three steps from raw text to a working AI assistant.

1

Pre-Training

Learn language from billions of pages of text. Output: a base model that can predict words.

2

Post-Training

Supervised Fine-Tuning (SFT): Train on expert Q&A examples.
RLHF: Optimise with human preference feedback.

3

Deployed Model

Ready to chat, answer questions, and assist. This is ChatGPT, Claude, Gemini.

WMSS focuses on Step 2 - specifically, making Supervised Fine-Tuning more effective when it hits a performance ceiling.
SFT Explained

What is Supervised Fine-Tuning?

Teaching a model to follow instructions using curated examples.

1

Experts Create Data

Human experts write high-quality question-answer pairs. These become the training examples.

2

Train the Model

Model learns to mimic expert answers.

Q: Why is the sky blue?
A: Sunlight scatters off air molecules. Blue light scatters more because of its shorter wavelength.
3

Model Improves

Now follows instructions and gives helpful answers. Ready for the next step: RLHF alignment.

The Problem

The Overconfident Student

To understand why SFT hits a ceiling, think about exam prep...

95%
Prep Exam
75%
Final Exam
S
Student
I scored 95%! I clearly know this material.
No need to keep practising. I'll just wait for the final.
E
Final Exam
Here are questions you've never seen before.
S
Student
I only got 75%?! The final had questions I'd never seen before.
I memorised correct answers, but never learned why the wrong ones were wrong.
LLM Parallel: During SFT, the model reaches 95%+ accuracy on training data. Gradients vanish - training stalls. The "final exam" is the unseen test set. The SFT model memorised training outputs but never learned to distinguish wrong answers. On new data, it fails to generalise.
Key distinction: This is not overfitting. Overfitting means the model learned spurious patterns (noise). Gradient saturation means the model memorised answers but learned too little about why alternatives are wrong.
Overview

The Journey of a Language Model

From raw text to saturation, and how WMSS breaks through.

Pre-training
Learn language
SFT
Learn the task
Saturation
Gradients vanish
WMSS
Weak drives strong
The core idea: Instead of learning from a stronger teacher (knowledge distillation), WMSS uses a weaker historical checkpoint to inject structured uncertainty that reactivates vanishing gradients and drives the strong model to keep improving.
Phase 0

Pre-training: Learning Language

The base model learns to predict the next token from a huge text corpus. Its predictions start nearly uniform across the vocabulary.

Next-Token Prediction

Context: "The derivative of x² is" → predict next token from vocabulary: ["2x", "x", "2", "x²", "dx"]

Base Model Probabilities (nearly uniform)
The base model assigns roughly equal probability to all tokens. It has learned language patterns but hasn't been fine-tuned for any specific task.
Phase 1

Supervised Fine-Tuning & Saturation

SFT sharpens the model's predictions toward the correct answer. But once the model becomes highly confident, gradients vanish and training stalls.

$$\mathcal{L}_{\text{SFT}} = -\log P_\theta(y \mid x) \qquad \frac{\partial \mathcal{L}}{\partial z_k} = P_\theta(k \mid x) \;\; \text{for } k \neq y$$

The gradient on non-target tokens equals their probability - as $P(k|x) \to 0$, the gradient vanishes.
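This gradient collapse is easy to check numerically. The sketch below uses invented logits and plain Python (no ML libraries) to show that once the target logit dominates, the gradient mass on non-target tokens vanishes:

```python
import math

def softmax(logits):
    m = max(logits)  # subtract max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def ce_grad(logits, target):
    """Gradient of L = -log P(target) w.r.t. each logit z_k:
    P(k) for k != target, and P(target) - 1 for the target."""
    p = softmax(logits)
    return [p[k] - (1.0 if k == target else 0.0) for k in range(len(logits))]

# Early in SFT: logits are close together, non-target gradients are large
early = ce_grad([1.0, 0.8, 0.6, 0.4, 0.2], target=0)
# After saturation: the target logit dominates, non-target gradients vanish
late = ce_grad([10.0, 0.8, 0.6, 0.4, 0.2], target=0)

print(sum(early[1:]))  # sizeable: the model still has something to fix
print(sum(late[1:]))   # near zero: training has stalled
```

The non-target gradient mass is exactly the probability left on wrong answers, which is why a confident model stops learning.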

Interactive SFT Training

Target token: "2x" - click Step to train

Probabilities $P(k|x)$
Gradient Magnitude $|\nabla_{z_k}|$
Epoch 0 / 6
P(2x)
Total ∇ (non-target)
Entropy H
The Fix

Learning from Mistakes

What if the student had a study partner who keeps challenging them with tricky wrong answers?

S
Student
New plan: I'll study why the wrong answers are wrong, not just memorise correct ones.
P
Study Partner
Here's a tricky wrong answer - can you explain why it's wrong?
I'm weaker than you, but my confusion highlights what you haven't truly learned.
S
Student
Ah, I see! That answer is wrong because...
Now I truly understand the material.
75%
Before
Review
Studies Mistakes
99%
After
This is exactly what WMSS does! A weaker model acts as the "study partner" - it injects plausible wrong answers, forcing the strong model to keep learning. Gradients stay alive, and performance breaks through the ceiling.
From Analogy to WMSS
The Student
=
The LLM (WMSS)
Memorises right answers (95%)
=
SFT learns correct outputs
Studies why wrong answers are wrong
=
Weak model injects wrong answers
Clear understanding (99%)
=
Breaks through the ceiling!
Context

Three Paradigms for Improving Models

Where does the learning signal come from? Three paradigms - each using a different source and direction of supervision.

Traditional

Knowledge Distillation

Large → small

A larger, more capable model teaches a smaller one. Knowledge flows down.

Requires a stronger model. When student catches up, signal vanishes.

Self-Improving

Self-Distillation

Current ← Past Self

An earlier checkpoint of the same model provides uncertainty signals. No external model needed.

Uses the model's own history as teacher.

WMSS (This Paper)

Weak-Driven Learning

$M_{\text{weak}} \rightarrow M_{\text{strong}}$

An older, weaker checkpoint injects uncertainty upward into the strong model's training signal.

No extra model. No extra inference cost. Just one old checkpoint.

Phase 2

Weak Agent vs Strong Agent

WMSS saves an earlier checkpoint as the Weak Agent and continues training the current model as the Strong Agent. The weak agent retains a softer decision boundary with probability mass on "hard negatives."

$$M_{\text{weak}} \leftarrow M_0 \quad(\text{base checkpoint}), \qquad M_{\text{strong}} \leftarrow M_1 \quad(\text{after SFT})$$
Side-by-Side Comparison
Weak Agent (checkpoint $M_0$)
Entropy H
Strong Agent (after SFT $M_1$)
Entropy H
Notice: the weak agent assigns significant probability to "x" and "2". These are hard negatives: plausible but incorrect tokens that the strong agent has suppressed. WMSS will use these signals.
Phase 2A

Curriculum-Enhanced Data Activation

Not all training samples are equally useful. WMSS uses entropy dynamics between weak and strong agents to weight samples by three signals.

Understanding Entropy

Entropy measures how uncertain a model is about its prediction. A flat distribution = high uncertainty. A peaked distribution = high confidence.

Low Entropy = Confident
One token dominates - model is sure
High Entropy = Uncertain
Probability spread evenly - model is guessing
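The two cases above can be computed directly from the entropy formula. The distributions below are made up for illustration:

```python
import math

def entropy(p):
    """Shannon entropy in nats: H = -sum(p * log p). Higher = more uncertain."""
    return -sum(q * math.log(q) for q in p if q > 0)

confident = entropy([0.90, 0.04, 0.03, 0.02, 0.01])  # one token dominates
uncertain = entropy([0.20, 0.20, 0.20, 0.20, 0.20])  # flat: pure guessing

print(round(confident, 2))  # low entropy
print(round(uncertain, 2))  # log(5) ~ 1.61, the maximum for 5 tokens
```

A uniform distribution over V tokens always attains the maximum entropy log(V), which is why flatness and uncertainty are interchangeable here.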
$$\Delta H = H(M_{\text{strong}}) - H(M_{\text{weak}})$$
$\Delta H < 0$ - Learned
$M_{\text{strong}}$ is more confident - it learned this concept well.
$\Delta H > 0$ - Regressed
$M_{\text{strong}}$ is less certain - it forgot what was known.
$$p_i \;\propto\; \alpha \cdot H(M_{\text{weak}};\,x_i) \;+\; \beta \cdot [-\Delta H_i]_+ \;+\; \gamma \cdot [\Delta H_i]_+ \qquad \Delta H_i = H(M_{\text{strong}}) - H(M_{\text{weak}})$$
How Samples Get Weighted

Each training sample gets scored by three signals. The model then focuses training on the most informative samples.

α
Base Difficulty

How confused was the weak model? Higher entropy = harder problem = more weight.

β
Consolidation

Did the strong model improve? If yes (ΔH < 0), revisit to stabilize what was learned.

γ
Regression Repair

Did the strong model get worse? If yes (ΔH > 0), up-weight to recover lost ground.

Three Samples, Three Outcomes
EASY - ALREADY MASTERED
"What is 2+2?"
Weak model: uncertain (H=0.8)
Strong model: confident (H=0.2)
ΔH = -0.6 (learned it)
Training weight:
Low - skip, already mastered
REGRESSION - FORGOT
"def fibonacci(n):"
Weak model: uncertain (H=0.8)
Strong model: MORE uncertain (H=1.3)
ΔH = +0.5 (got worse!)
Training weight:
Medium - needs review to recover
GOLDILOCKS - MOST INFORMATIVE
"Solve: ∫sin(x)dx"
Weak model: very uncertain (H=1.8)
Strong model: still uncertain (H=1.6)
ΔH = -0.2 (barely improved)
Training weight:
HIGH - both models struggle, most to learn
Result: The model focuses training on samples where both weak and strong models are similarly uncertain - the "Goldilocks zone" where the most learning can happen. Easy samples get skipped; forgotten samples get reviewed.
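The three outcomes above follow directly from the weighting formula. The coefficients α, β, γ below are illustrative choices, not values reported in the paper:

```python
def ceda_weight(h_weak, h_strong, alpha=1.0, beta=0.1, gamma=0.8):
    """Unnormalised curriculum weight, per the formula
    p_i ∝ α·H(weak) + β·[-ΔH]_+ + γ·[ΔH]_+  with  ΔH = H(strong) - H(weak).
    alpha/beta/gamma are illustrative, not the paper's tuned values."""
    dh = h_strong - h_weak
    return alpha * h_weak + beta * max(-dh, 0.0) + gamma * max(dh, 0.0)

easy       = ceda_weight(0.8, 0.2)  # ΔH = -0.6: already mastered
regression = ceda_weight(0.8, 1.3)  # ΔH = +0.5: got worse
goldilocks = ceda_weight(1.8, 1.6)  # ΔH = -0.2: both still struggle

# Goldilocks gets the highest weight, regression medium, easy lowest
print(easy, regression, goldilocks)
```

With these (hypothetical) coefficients the ordering matches the cards above: the Goldilocks sample dominates because the weak model's base difficulty term α·H(weak) is largest there, while the regression term γ bumps up forgotten samples.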
Phase 2B - The Core

Joint Training via Logit Mixing

The heart of WMSS: mix the weak agent's logits with the strong agent's logits to create a joint distribution. This reintroduces probability mass on hard negatives, amplifying gradients that had vanished.

$$z_{\text{mix}} = \lambda \cdot z_{\text{strong}} + (1 - \lambda) \cdot z_{\text{weak}} \qquad \mathcal{L}_{\text{mix}} = -\log P_{\text{mix}}(y \mid x)$$
Concrete Example

Predict next token after "The derivative of x² is" - ground truth = 2x

Strong Model (after SFT)
2x: 97%
x: 1.5% · 2: 0.8%
x²: 0.5% · dx: 0.2%
Gradient nearly zero
Weak Model (pre-SFT)
2x: 30%
x: 25% · 2: 22%
x²: 15% · dx: 8%
Still considers alternatives
After Mixing ($\lambda$=0.5)
2x: 64%
x: 13% · 2: 11%
x²: 8% · dx: 4%
Gradients restored!
Interactive Logit Mixing
$\lambda$ 0.70
Strong Logits
Weak Logits
Mixed Logits ($z_{\text{mix}}$)
$P_{\text{strong}}$  vs  $P_{\text{mix}}$ - Probability Comparison
$\nabla_{\text{strong}}$  vs  $\nabla_{\text{mix}}$ - Gradient Magnitudes (non-target tokens)
Total ∇ (strong only)
Total ∇ (mixed)
Amplification
Key insight: As $\lambda$ decreases, the weak agent's softer distribution injects more probability mass onto hard negatives. This directly amplifies gradient magnitudes via $\frac{\partial\mathcal{L}}{\partial z_k} = P_{\text{mix}}(k|x)$, reactivating learning in saturated regions.
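The amplification in the demo can be verified in a few lines. The logits below are invented for illustration: the strong model is saturated, the weak one still spreads mass over hard negatives:

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [x / s for x in e]

def nontarget_grad_mass(z_strong, z_weak, lam, target=0):
    """Total gradient magnitude on non-target logits under the mixed loss:
    z_mix = lam*z_strong + (1-lam)*z_weak; dL/dz_k = P_mix(k) for k != target."""
    z_mix = [lam * s + (1 - lam) * w for s, w in zip(z_strong, z_weak)]
    p = softmax(z_mix)
    return sum(p[k] for k in range(len(p)) if k != target)

z_strong = [8.0, 1.0, 0.5, 0.0, -0.5]  # saturated: target token dominates
z_weak   = [1.2, 1.0, 0.9, 0.4,  0.0]  # soft: hard negatives still alive

g_strong = nontarget_grad_mass(z_strong, z_weak, lam=1.0)  # strong alone
g_mixed  = nontarget_grad_mass(z_strong, z_weak, lam=0.5)  # WMSS mixing

print(g_mixed / g_strong)  # amplification factor well above 1
```

Because the gradient on each non-target logit equals its mixed probability, pulling probability mass back onto distractors is the same thing as restoring their gradients.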
Head to Head

SFT vs WMSS: Epoch by Epoch

Token probability distribution across training epochs - 5-token vocabulary, one correct answer. Watch how SFT saturates while WMSS keeps distractors alive for learning.

Interactive Epoch Comparison

Standard SFT

Gradient Alive

WMSS

Same Start
Epoch 1: Both start the same - gradients are alive and the model knows what to fix.
Algorithm 1

The Complete Pipeline

WMSS follows three phases. Watch the algorithm execute step by step.

WMSS Pipeline Animation
Base Model $M_0$
Pre-trained base
SFT → $M_1$
Phase 1: Init
Curriculum
Phase 2A: $\Delta H$
Logit Mix
Phase 2B: $z_{\text{mix}}$
$M_{\text{strong}}^+$
Stronger!
Click Play Pipeline to watch WMSS execute step by step.
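The whole pipeline can be sketched end to end on toy logits. This is an illustrative reconstruction from the formulas above, not the paper's code; the coefficients and batch values are made up:

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [x / s for x in e]

def entropy(p):
    return -sum(q * math.log(q) for q in p if q > 0)

def wmss_loss(batch, lam=0.5, alpha=1.0, beta=0.1, gamma=0.8):
    """One WMSS loss computation over a toy batch.
    batch: list of (z_weak, z_strong, target_index) per sample.
    Phase 2A: weight samples by entropy dynamics between the two agents.
    Phase 2B: cross-entropy on the mixed logits z_mix.
    All coefficients are illustrative, not the paper's settings."""
    weights, losses = [], []
    for z_weak, z_strong, y in batch:
        h_w = entropy(softmax(z_weak))
        h_s = entropy(softmax(z_strong))
        dh = h_s - h_w                                  # ΔH = H(strong) - H(weak)
        weights.append(alpha * h_w + beta * max(-dh, 0) + gamma * max(dh, 0))
        z_mix = [lam * s + (1 - lam) * w for s, w in zip(z_strong, z_weak)]
        losses.append(-math.log(softmax(z_mix)[y]))     # loss on the mixture
    total = sum(weights)
    return sum(w / total * l for w, l in zip(weights, losses))

batch = [
    ([1.2, 1.0, 0.9], [8.0, 1.0, 0.5], 0),  # saturated sample
    ([0.5, 0.4, 0.3], [0.6, 0.5, 0.4], 0),  # still-uncertain sample
]
loss = wmss_loss(batch)
print(loss)
```

In a real run this loss would be backpropagated into the strong model only; the weak checkpoint is frozen and just supplies logits.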
Why It Works

Mechanistic Insights

WMSS self-regulates through three inherent mechanisms - no hyperparameter tuning required.

Property 1: Saturated-Region Amplification

The softmax function creates an S-curve. At the extremes (very high or very low logits), the curve is flat - gradients vanish. WMSS shifts tokens back to the steep middle region.

[Softmax S-curve: probability P (0-100%) vs logit z (-4 to +4) - flat (saturated) at both extremes, steep in the middle]
Why saturation at both ends?
z very negative: P approaches 0%
z very positive: P approaches 100%
At extremes, curve is flat = tiny gradients
WMSS Solution
Mix with weak model logits - shifts tokens back to the steep region where gradients are strong.
Property 2: Gradient Shielding

The weak model's influence naturally fades as the strong model improves. No manual tuning needed.

Early Training
W
S
Weak influence: High
Weak model guides learning
Mid Training
W
S
Weak influence: Fading
Strong model catching up
Late Training
w
S
Weak influence: ~Zero
Strong model takes over
The weak model helps early, then naturally fades. No hyperparameter scheduling needed - it's built into the math.
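The fading can be checked numerically: holding the weak checkpoint's logits fixed, the mixed gradient shrinks on its own as the strong model's target logit grows through training. All values below are illustrative:

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [x / s for x in e]

def mixed_nontarget_grad(target_logit, lam=0.5):
    """Non-target gradient mass when mixing a fixed, soft weak model with
    a strong model whose target logit keeps growing during training."""
    z_strong = [target_logit, 1.0, 0.5, 0.0]
    z_weak   = [1.2, 1.0, 0.9, 0.4]   # frozen weak checkpoint
    z_mix = [lam * s + (1 - lam) * w for s, w in zip(z_strong, z_weak)]
    p = softmax(z_mix)
    return sum(p[1:])

early, mid, late = (mixed_nontarget_grad(z) for z in (2.0, 8.0, 20.0))
print(early, mid, late)  # steadily shrinking: the weak signal fades by itself
```

No schedule is needed: once the strong model is confident enough, its logits dominate the mixture and the weak model's contribution becomes negligible.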
Property 3: Null-Space Drift

Shifting all scores by the same amount doesn't change which token wins. Only relative gaps matter.

Original Scores
C
B
A
C=1, B=3, A=5 (winner)
Gaps: A-B = 2, A-C = 4
+3
to all
Shifted Scores (+3)
C
B
A
C=4, B=6, A=8 (still winner!)
Gaps: A-B = 2, A-C = 4 (unchanged)
Uniform logit shifts from weak-model mixing are harmless. The decision boundary stays the same because softmax only cares about relative differences, not absolute values.
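The shift-invariance above takes one comparison to confirm:

```python
import math

def softmax(z):
    m = max(z)  # subtracting the max already removes any uniform shift
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [x / s for x in e]

scores  = [1.0, 3.0, 5.0]            # C, B, A - A wins
shifted = [z + 3.0 for z in scores]  # uniform shift, as from weak-model mixing

p_orig, p_shift = softmax(scores), softmax(shifted)
print(p_orig)
print(p_shift)  # identical: softmax only sees relative gaps
```

Adding the same constant to every logit cancels in the softmax ratio, so the probabilities (and the winner) are unchanged.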
Comparison

What Each Method Actually Does

How it works - and whether it solves the gradient problem.

Standard SFT

0% gain
Baseline: train on correct tokens only. No escape from the plateau.

NEFTune

+1–2%
Blind regulariser: adds noise to embeddings. Doesn't fix gradient collapse.

UNDIAL

−1.4%
Unlearning method: penalises the correct token. Wrong tool for this job.

WMSS

+5–8%
Targets distractors via logit mixing. Gradient stays alive.

Click a method card above to see how it works.

Results

Breaking the Saturation Ceiling

Concrete results on Qwen3-8B-Base · 2 epochs · averaged over 3 runs.

+6.2%
Math Avg vs SFT
+6.4%
Code Avg vs SFT
3–5x
More gain than NEFTune
Benchmark Results (Qwen3-8B-Base, Table 1)
Method | Math Avg | Code Avg | vs SFT (Math) | vs SFT (Code)
SFT baseline | 66.7% | 71.2% | — | —
UNDIAL | 67.7% | 70.4% | +1.0% | −0.8%
NEFTune | 68.5% | 72.4% | +1.8% | +1.2%
WMSS | 72.9% | 77.6% | +6.2% | +6.4%
Ablation Study (Qwen3-4B-Base)

Each bar is an independent experiment vs SFT baseline - not cumulative.

+2.2%
CEDA Only
+1.9%
JTWS Only
+5.8%
Full WMSS
Synergy: CEDA + JTWS individually give +2.2% and +1.9%, but full WMSS gives +5.8% - an extra +1.7% from synergy beyond the sum of parts.
Training Curves: SFT vs WMSS
Standard SFT WMSS (Weak-Driven)
Convergence: Most improvement happens by epoch 4. Beyond that, diminishing returns. Easy tasks (GSM8K at 96%+) stay stable - no catastrophic forgetting.
Limitations

What WMSS Hasn't Proven Yet

Scope boundaries stated directly in the paper. No speculation added.

01

Model Scale

Only Qwen3-4B and 8B tested. Behaviour on 70B+ models is completely unknown.

02

Task Domains

Math reasoning and code generation only. No NLU, translation, or dialogue results reported.

03

Architecture

All experiments use the Qwen3 family. Generalisation to Llama, Mistral, or others not demonstrated.

04

Weak Agent Selection

Always uses $M_0$, the pre-SFT checkpoint. What makes an optimal weak agent remains unexplored.

05

Over-Optimisation

AMC2023 regresses after Epoch 3; GSM8K shows volatility. Epoch 4 risks catastrophic forgetting.

06

Mixing Coefficient $\lambda$

Performance is sensitive to $\lambda$. Peak average at $\lambda$=0.42; drops meaningfully at extremes.

Summary

Where It Fits, What It Proves, What Comes Next

The Problem

SFT saturates once models grow confident. Gradients vanish - more training stops helping.

The only known fixes: bigger models or more data.

The Contribution

  • Mix weak + strong logits ($\lambda$=0.5)
  • Inflates distractor probabilities, keeping gradients alive past saturation
  • Zero extra inference cost
  • +6.2% Math avg, +6.4% Code avg vs SFT
  • Distractor logits suppressed by 56.9%

Open Questions

  • 70B+ models untested
  • Math + code only
  • One architecture (Qwen3)
  • No RLHF comparison
  • Optimal weak checkpoint selection unexplored
"The key to stronger models may lie in understanding their weaker selves."
- Chen et al. (2026)