GATES: How Consensus Gating Fixed the Broken Promise of Self-Distillation in Language Models | AI Trend Blend


Researchers at the University of Maryland trained a model to teach itself — without any ground-truth labels, external graders, or verifiable rewards — by sampling multiple document-grounded reasoning traces and only distilling the ones its own internal jury agreed on.

Tags: Self-Distillation · Privileged Context · Consensus Gating · Trajectory Distillation · Document-Grounded QA · Asymmetric Context · REINFORCE · Math Reasoning
[Figure 1 diagram: GATES — Gated Asymmetric Trajectory Self-Distillation. Tutor πθ(· | doc, q) samples k = 8 rollouts → consensus gate (≥ τ/k agree → pass; else skip, zero loss) → eligible trajectories train the student πθ(· | q only) via a distillation loss; both roles share weights. In-domain student 46% → 62%; benchmark maj@8 avg. 20.2% → 35.4%; MATH 65.6%. Stein, Huang, Goldstein · arXiv:2602.20574v1 [cs.LG] · University of Maryland, College Park · 2026]
Figure 1: GATES overview — a single model πθ simultaneously plays tutor (with document context) and student (question only). Multiple tutor rollouts are sampled; a consensus gate passes only questions where sufficiently many rollouts agree. Eligible tutor trajectories train the student via off-policy distillation. No ground-truth labels or external reward models are involved.

Self-distillation is one of machine learning’s most seductive ideas: what if a model could teach itself, improving purely on its own output, without any human labels or external feedback? The reality has always been less inspiring. A model that trains on its own mistakes bakes those mistakes in deeper, and a model that trains on its own successes can overfit to whatever shortcuts happened to work. The University of Maryland team behind GATES found a way around both failure modes — not by building a better teacher, but by figuring out when not to learn at all.

The core setup is deceptively simple. Take a single language model. Give it a question derived from a reference document. When it can see that document, call it the tutor; when it cannot, call it the student. Sample eight tutor responses. If most of them agree on the same answer, treat that question as trustworthy enough to learn from. Distill the full reasoning trajectory — not just the final answer — into the document-free student. Skip every question where agreement is weak. That is the whole method, and it moves student accuracy from 46% to 62% on held-out questions while lifting benchmark math performance from 20.2% to 35.4%.
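In code, the question-level gate is just a thresholded majority vote over extracted answers. A minimal sketch with the paper's τ = 4 of k = 8 (function name is illustrative, and plain string matching stands in for the symbolic answer equivalence the paper uses):

```python
from collections import Counter

def consensus_gate(answers, tau=4):
    """Question-level gate: pass only if at least tau of the k rollout
    answers agree. `answers` holds the final answers extracted from the
    tutor rollouts (None where no parsable answer was found)."""
    counts = Counter(a for a in answers if a is not None)
    if not counts:
        return False, None
    modal, n = counts.most_common(1)[0]
    return (n >= tau), (modal if n >= tau else None)

# 6 of 8 rollouts agree on "7": the question is eligible for distillation
passed, ans = consensus_gate(["7", "7", "7", "21", "7", "7", None, "7"])

# 3-3 split: the question is skipped and contributes zero loss
skipped, _ = consensus_gate(["7", "7", "7", "21", "21", "21", None, None])
```

Everything else in the method builds on this one decision: a passed question yields trajectories to distill, a skipped question contributes nothing at all.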

What makes GATES stand out is not the cleverness of any individual component, but the precision with which it identifies exactly what prior self-training approaches got wrong. The failure is not that models are bad teachers; it is that they cannot distinguish when their teaching is reliable. The consensus gate supplies that missing judgment — and once you have it, dense trajectory distillation can do the rest.


The Self-Distillation Trap

Standard language model fine-tuning lives comfortably inside a well-understood loop: you have some inputs, you have some correct outputs, you minimize a loss. The challenge the GATES authors set out to crack is far harder: what if you have none of the correct outputs? No labels, no reward model, no external grader. The model must somehow improve using only its own generated text.

This is not an academic curiosity. A massive fraction of the world’s knowledge is locked inside documents that models could theoretically reason over — legal texts, scientific papers, technical manuals, historical records. A model that could learn to internalize reasoning from those documents, without needing humans to annotate correct answers for every question, would be extraordinarily useful. The trouble is that every naive approach to this problem fails in the same way: the model reinforces its own errors.

Supervised fine-tuning on just the extracted final answer is particularly brutal in its failure. In the GATES experiments, this approach collapsed student accuracy all the way down to 10–12%. The model learned to produce answer-shaped text without learning any of the reasoning that connects questions to answers. It is the machine learning equivalent of a student who has memorized what good essays look like without understanding what makes them good.

Outcome-based reinforcement learning does slightly better, but only slightly — reaching about 21.3% benchmark accuracy, barely above the 20.2% of the untrained base model. Sparse correctness signals at the end of a long reasoning chain provide almost no useful gradient information about which individual reasoning steps were responsible for getting things right or wrong.

The real question is whether there is any reliable signal to exploit at all when labels are absent. The GATES answer is yes — but it requires looking in an unusual place. Rather than asking whether each answer is correct, the method asks whether multiple independent attempts at the same question tend to agree. Agreement is not a perfect proxy for correctness, but it is a surprisingly good one, and crucially, it is a signal the model can generate for itself.

Key Takeaway

GATES sidesteps the self-training error-amplification trap by gating on consensus rather than correctness. When multiple document-grounded tutor rollouts agree on an answer, the method treats that question as reliable enough to learn from. When they disagree, the question contributes zero loss — preventing the model from reinforcing its own uncertainty.


The Architecture: One Model, Two Roles

[Figure 2 diagram: GATES training pipeline — (A) privileged tutor sampling πθ(·|d, q), k independent rollouts; (B) two-level consensus gating (question-level g_i, rollout-level e_{i,j}) with doc-leakage guardrails, g_i = 0 → zero loss; (C) trajectory distillation into the student πθ(·|q): ℒ(θ) = λ_off·ℒ_off + λ_on·ℒ_on with λ_off = 1.0, λ_on = 0.1; tutor and student share identical weights]
Figure 2: The three-phase GATES training pipeline. Phase A samples k tutor rollouts under document context. Phase B applies two-level consensus gating — question-level (does enough agree?) and rollout-level (does this rollout match?) — with document-leakage guardrails. Phase C distills eligible tutor trajectories into the document-free student via off-policy NLL loss, with an optional on-policy advantage-weighted component. Both roles share weights and update together.

The design of GATES begins with an asymmetry that is both obvious and underexploited. When a model answers a question with the relevant source document in context, it has access to evidence that most students of the subject would kill for. When it answers the same question without the document, it has only what it has internalized from pretraining. These two situations call for fundamentally different reliability assumptions.

The GATES framing crystallizes this into a clean formal structure. A single model πθ operates in two modes distinguished only by their input: the tutor receives (document, question) and the student receives (question) alone. They share every weight. Every gradient update that changes the student also changes the tutor. There is no separate teacher model, no delayed copy, no distillation from a larger external network. The entire knowledge transfer happens within a single set of parameters.

This matters more than it might seem. In standard distillation, you assume the teacher is reliably better than the student. But here the tutor is the same model — so you cannot simply trust everything it produces. What makes the tutor more reliable is not superior capability but access to the source document. When that access genuinely helps, the tutor will produce better, more consistent answers. When the tutor still gets things wrong despite document access, its answers will be inconsistent across rollouts. Consensus gating is precisely calibrated to detect and exploit this difference.
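Concretely, the two roles differ only in the prompt; the weights are identical. A sketch with illustrative prompt templates (the paper's exact prompt formatting is not reproduced here):

```python
def tutor_prompt(document: str, question: str) -> str:
    # Privileged context: the tutor conditions on (document, question).
    return f"Document:\n{document}\n\nQuestion: {question}\n\nAnswer:"

def student_prompt(question: str) -> str:
    # Same model, same weights: the student sees only the question.
    return f"Question: {question}\n\nAnswer:"

# Both prompts go to the *same* model pi_theta; every gradient update
# changes tutor and student behaviour simultaneously.
t = tutor_prompt("42 = 2 * 3 * 7 ...", "What is the largest prime factor of 42?")
s = student_prompt("What is the largest prime factor of 42?")
```

There is no frozen teacher checkpoint anywhere in this picture; the "teacher" is just a more informed conditioning of the current parameters.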

Why the Consensus Gate Is the Active Ingredient

The most important empirical finding in the paper is also the cleanest: removing the consensus gate while keeping every other component identical drops benchmark average accuracy from 35.4% to 31.1%, and in-domain student accuracy from 62% to 54%. These numbers essentially match the performance of Tutor-Trajectory SFT — an approach that trains on all tutor rollouts regardless of agreement. The gate is not a minor regularization trick. It is the piece that separates effective self-distillation from self-reinforcement of noise.

The mechanism behind this gap is worth dwelling on. Without the gate, the student must learn from tutor trajectories that include cases where the tutor reasoned poorly even with the document in hand. These trajectories often contain implicit dependence on the document — references to facts, patterns of inference that only make sense when you can re-read the source, logical shortcuts that the document permits but bare reasoning cannot. The student internalizes these patterns, then fails at test time because the document is not there to make them valid.

“The tutor can recover from flawed intermediate steps by re-reading the source, but the student cannot. Consensus gating filters out these document-dependent trajectories, ensuring the student learns only from reasoning that is self-sufficient.” — Stein, Huang & Goldstein, arXiv:2602.20574v1

Gating on consensus elegantly solves this problem without any explicit detection of document dependence. When a trajectory is only valid if you can consult the document, independent rollouts will disagree with each other — because different rollouts happen to use different parts of the document, or recover from different errors, and those idiosyncratic paths do not converge to a common answer. Consensus naturally filters them out.


Training Objectives: Two Flavors of Dense Supervision

Once the consensus gate has determined that a question is trustworthy, GATES extracts everything it can from that question’s tutor rollouts. Rather than taking just the final answer and computing a sparse binary reward, it distills the complete reasoning trajectory — every token from the first word of the chain-of-thought to the boxed final answer. This density is deliberate and critical.

Off-Policy Distillation

The primary training signal in GATES is off-policy trajectory distillation: the student is trained to predict each token in the tutor’s reasoning trace, conditioned only on the question. This is negative log-likelihood on tutor tokens, gated by both the question-level consensus indicator and the rollout-level eligibility indicator:

Eq. 1 — Off-Policy Distillation Loss $$\mathcal{L}_{\text{off}}(\theta) = -\frac{1}{\sum_{i,j} g_i e_{i,j} L^{(T)}_{i,j}} \sum_{i,j} g_i e_{i,j} \sum_{t=1}^{L^{(T)}_{i,j}} \log \pi_\theta\!\left(y^{(T)}_{i,j,t} \;\middle|\; y^{(T)}_{i,j,<t},\, q_i\right)$$

Here g_i is the question-level gate indicator (1 if ≥ τ rollouts agree, else 0), e_{i,j} is the rollout-level eligibility indicator (1 if rollout j matches the consensus answer a* and passes the leakage checks), and L^{(T)}_{i,j} is the length of tutor rollout j. The denominator normalizes across all eligible tokens so that no single long trajectory dominates.

Each token in an eligible tutor trajectory contributes a gradient signal. The student learns not just what the right answer is, but how to get there — what intermediate steps look like when reasoning is clean and self-sufficient. Compared to sparse terminal rewards, this provides the kind of dense, stable training signal that is far easier for a policy network to learn from.
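Stripped to its essentials, Eq. 1 is a gated, token-count-normalized NLL. A sketch under the assumption that per-token student log-probs on the tutor tokens are already computed (in the real pipeline they come from a forward pass on [question | tutor trajectory]):

```python
import torch

def off_policy_nll(token_logps, gates, elig):
    """Eq. 1 as a sketch: token_logps[i][j] is a 1-D tensor of student
    log-probs log pi_theta(y_t | y_<t, q_i) on tutor rollout j of
    question i; gates[i] = g_i, elig[i][j] = e_{i,j}."""
    num = torch.tensor(0.0)
    den = 0
    for i, rollouts in enumerate(token_logps):
        for j, lp in enumerate(rollouts):
            if gates[i] and elig[i][j]:
                num = num + lp.sum()
                den += lp.numel()
    return -num / den if den > 0 else None  # gated out => no loss at all

lp = [[torch.log(torch.tensor([0.5, 0.5])),   # eligible rollout, 2 tokens
       torch.log(torch.tensor([0.25]))]]      # eligible rollout, 1 token
loss = off_policy_nll(lp, gates=[1], elig=[[1, 1]])
```

Note the `None` return when nothing is eligible: a low-consensus question does not produce a small loss, it produces no loss, which is the entire point of the gate.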

On-Policy Distillation

The on-policy component works differently: rather than imitating tutor tokens directly, it lets the student generate its own reasoning traces and then uses the tutor as a per-token scorer. For each token the student produced, the system asks whether the document-aware tutor would have been more or less likely to produce that same token. Tokens where the tutor assigns higher probability get upweighted; tokens where the student seems to “know better” are left alone.

Eq. 2 — Per-Token Advantage $$A_t = \text{clip}\!\left(\log \pi_T\!\left(y^{(S)}_t \;\middle|\; y^{(S)}_{<t}, d, q\right) - \log \pi_\theta\!\left(y^{(S)}_t \;\middle|\; y^{(S)}_{<t}, q\right),\; [-a, a]\right)$$

The advantage A_t is computed with no gradient; it is purely a per-token weight derived from the tutor’s hindsight. Tokens where the tutor’s document-grounded probability substantially exceeds the student’s receive a large positive weight. Tokens where the student’s probability is already comparable to or exceeds the tutor’s are de-weighted naturally. The clipping bound a = 5.0 prevents outlier tokens from dominating the gradient.
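Eq. 2 can be sketched in a few lines, assuming the tutor and student log-probs of the student's sampled tokens are available:

```python
import torch

def per_token_advantage(tutor_logps, student_logps, a=5.0):
    """Eq. 2: clipped log-prob gap between the document-aware tutor and
    the document-free student, computed with no gradient."""
    with torch.no_grad():
        return torch.clamp(tutor_logps - student_logps, min=-a, max=a)

tutor = torch.log(torch.tensor([0.9, 0.1, 0.5]))    # tutor token probs
student = torch.log(torch.tensor([0.1, 0.9, 0.5]))  # student token probs
adv = per_token_advantage(tutor, student)
# token 0: tutor far more confident -> large positive weight (~ +2.2)
# token 1: student "knows better"   -> negative weight, de-emphasized
# token 2: agreement                -> weight ~ 0
```

The `no_grad` context matters: gradients flow only through the student log-probs in Eq. 3, never through the advantage itself.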

Eq. 3 — On-Policy Loss $$\mathcal{L}_{\text{on}}(\theta) = -\frac{1}{\sum_{i,j} g_i L^{(S)}_{i,j}} \sum_{i,j} g_i \sum_{t=1}^{L^{(S)}_{i,j}} A_{i,j,t} \cdot \log \pi_\theta\!\left(y^{(S)}_{i,j,t} \;\middle|\; y^{(S)}_{i,j,<t},\, q_i\right)$$

This formulation stays on-policy — the student trains on its own rollouts — while still benefiting from the document-grounded perspective the tutor can provide. In practice, the on-policy component provides modest additional improvement on top of off-policy distillation, but off-policy remains the primary driver.

Total Objective and KL Regularization

The total training objective is a weighted sum of the off-policy loss, on-policy loss, and a KL divergence penalty that prevents the student from drifting too far from the pretrained reference policy:

Eq. 4 — Total Training Objective $$\mathcal{L}(\theta) = \lambda_{\text{off}} \mathcal{L}_{\text{off}} + \lambda_{\text{on}} \mathcal{L}_{\text{on}} + \lambda_{\text{KL}} \mathcal{L}_{\text{KL}}, \quad \lambda_{\text{off}} = 1.0,\; \lambda_{\text{on}} = 0.1,\; \lambda_{\text{KL}} = 0.02$$

The KL term is computed over student rollouts: it measures the expected per-token divergence between the current policy πθ and a frozen copy of the base model πref. This regularization is standard in RLHF-style training and prevents the model from collapsing into degenerate high-confidence behaviour on the narrow training distribution while losing general capabilities. The coefficient λKL = 0.02 is small enough that the distillation signal dominates throughout training.
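The penalty can be sketched with the standard single-sample estimator on student-rollout tokens (the paper's exact estimator is not reproduced here; this is the common RLHF form):

```python
import torch

def kl_penalty(cur_logps, ref_logps, lam=0.02):
    """Per-token KL estimate E[log pi_theta - log pi_ref] over the tokens
    of a student rollout, scaled by lambda_KL = 0.02. Inputs are log-probs
    of the sampled tokens under the current policy and under a frozen
    copy of the base model."""
    return lam * (cur_logps - ref_logps).mean()

cur = torch.log(torch.tensor([0.5, 0.5]))    # current policy, drifted up
ref = torch.log(torch.tensor([0.25, 0.25]))  # frozen reference
pen = kl_penalty(cur, ref)  # small positive penalty pulls theta back
```

Because λ_KL = 0.02 is tiny relative to λ_off = 1.0, this term acts as a leash rather than a competing objective.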

Key Takeaway

GATES distills complete reasoning trajectories, token by token, not just final answers. This dense supervision is what allows the student to learn the structure of good reasoning rather than merely the shape of correct outputs. The consensus gate decides when to distill; trajectory distillation determines what gets distilled. Off-policy NLL is the workhorse; on-policy advantage-weighting and KL regularization are stabilizers.


Experimental Validation: Numbers That Hold Up Under Scrutiny

[Figure 3 chart: GATES vs. baselines. Left, document-free benchmark avg. (maj@8): GATES 35.4%, Oracle Only (sparse RL) 33.2%, Tutor-Traj. SFT 32.3%, Outcome RL 21.3%, Base 20.2%, Answer-Only SFT 16.2%. Right, in-domain student accuracy: GATES 62%, Tutor-Traj. SFT 54%, Base 46%, Outcome RL 40%, Answer-Only SFT 12%, Answer-Only SFT (w/ doc) 10%. Qwen3-4B-Base · 551 training / 50 held-out questions · Stein et al., arXiv:2602.20574v1]
Figure 3: GATES vs. all baselines on document-free math benchmarks (left, maj@8 average) and in-domain asymmetric evaluation (right, greedy decoding). Answer-only fine-tuning methods catastrophically degrade student accuracy. Outcome RL provides no meaningful improvement over the base model. GATES outperforms all baselines on both metrics.

The experimental setup is deliberately austere. The underlying model is Qwen3-4B-Base — a 4-billion-parameter model, not an enormous frontier system backed by a matching compute budget. The training dataset contains just 551 questions, each derived from documents in the Nemotron-CC-Math corpus using Qwen2.5-32B-Instruct as the question generator. Fifty questions are held out for in-domain evaluation. No verified answers are used during training at any point.

Out-of-domain evaluation happens on four public math benchmarks: MATH, AMC, Minerva, and OlympiadBench. These are evaluated under majority voting over 8 samples (maj@8), which the authors note is a natural fit for their method — consensus-based training makes the model more consistent across samples, and agreement-based decoding directly benefits from that consistency.
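A minimal sketch of the maj@8 metric, with an illustrative function name and plain string matching in place of symbolic answer equivalence (gold answers appear here for evaluation only; they are never used in training):

```python
from collections import Counter

def maj_at_k(sampled_answers, gold):
    """maj@k accuracy: one majority vote over the k sampled answers for
    each question, then compare the winning answer to the reference."""
    correct = 0
    for samples, g in zip(sampled_answers, gold):
        votes = Counter(a for a in samples if a is not None)
        if votes and votes.most_common(1)[0][0] == g:
            correct += 1
    return correct / len(gold)

acc = maj_at_k(
    [["7"] * 5 + ["21"] * 3,   # majority "7" == gold -> correct
     ["3"] * 2 + ["5"] * 6],   # majority "5" != gold -> wrong
    ["7", "3"],
)
```

The structural similarity to the consensus gate is no accident: training with agreement as the signal and evaluating with agreement as the decoder reward the same kind of cross-sample consistency.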

| Method | MATH | AMC | Minerva | OlympiadBench | Avg. |
|---|---|---|---|---|---|
| GATES (ours) | 65.6 | 26.5 | 21.7 | 27.7 | 35.4 |
| Tutor-Trajectory SFT | 60.8 | 24.1 | 19.5 | 24.7 | 32.3 |
| Oracle Only (sparse RL) | 60.8 | 25.3 | 23.2 | 23.4 | 33.2 |
| Outcome RL | 44.0 | 10.8 | 18.4 | 12.0 | 21.3 |
| Base Model | 42.6 | 9.6 | 15.8 | 12.6 | 20.2 |
| Answer-Only SFT | 35.6 | 8.4 | 13.2 | 7.7 | 16.2 |
| Answer-Only SFT (w/ doc) | 35.8 | 10.8 | 14.3 | 7.3 | 17.1 |

Table 1: maj@8 accuracy on four document-free math benchmarks. GATES leads across every benchmark. Both answer-only SFT variants fall below even the base model on MATH, Minerva, and OlympiadBench.

The Tutor–Student Gap Tells the Real Story

One of the most revealing statistics in the paper is not the benchmark numbers themselves but the tutor–student accuracy gap for each method. Tutor-Trajectory SFT achieves the highest tutor accuracy of any method (74%) — but its student accuracy is only 54%, a gap of 20 percentage points. GATES achieves lower tutor accuracy (68%) but far better student accuracy (62%), narrowing the gap to just 6 points.

This pattern makes precise mechanistic sense. Without consensus gating, the student is trained on everything the tutor produces, including reasoning that implicitly depends on re-reading the source document. The tutor benefits from this because it has the document during inference. The student suffers because it does not. Consensus gating breaks this dynamic by filtering out trajectories where the tutor’s document access was doing meaningful work — cases where different rollouts told different stories because different document interactions led to different recoveries.

The same pattern appears in the ablations. Removing only the gate (−Gate) achieves tutor accuracy of 72% but student accuracy of 54% — nearly identical numbers to Tutor-Trajectory SFT. Every other component of GATES is present; only the gate is missing. This is about as clean an ablation as you can get: the gate is the thing that separates effective transfer from inflated-tutor, poor-student performance.


Ablations: Minimal GATES

The ablation study serves a specific purpose: to show that GATES is not stronger because it has more moving parts, but because its moving parts are the right ones. Seven configurations systematically vary the loss weights while holding all other hyperparameters fixed.

Off-policy distillation is by far the dominant contributor. Every configuration that removes it — on-policy dominant, on-policy plus oracle — shows substantial drops in both benchmark and in-domain accuracy. The on-policy component provides modest supplementary improvement, particularly when combined with off-policy distillation, but it cannot compensate for off-policy’s absence.

Adding the consensus-correctness reward (treating tutor agreement as a sparse REINFORCE signal) provides essentially no improvement on benchmarks and actually reduces in-domain student accuracy from 62% to 54%. This is counterintuitive at first glance but mechanistically sensible: the oracle reward upweights correct trajectories that the consensus gate would otherwise filter, partially reintroducing document-dependent reasoning into the student’s training. The consensus gate protects against exactly this kind of contamination; adding a signal that bypasses the gate’s judgment undermines the protection.

The conclusion is unusually clean for machine learning research. The minimally necessary form of GATES is: consensus gate plus off-policy trajectory distillation. Everything else either provides marginal improvement or actively hurts. This kind of principled simplicity is rarer than it should be.

Key Takeaway

The full GATES system requires exactly two things beyond a standard self-training setup: a consensus gate that decides when to learn, and off-policy trajectory distillation that determines what to learn. The sparse oracle reward is unnecessary. The on-policy term helps modestly. The gate is non-negotiable.


Adaptive Challengers: A Promising Extension

The experiments described so far use a fixed-challenger setup: questions are pre-generated offline before training begins, and the training distribution does not change. The paper also explores an adaptive variant where questions are generated on the fly, creating a non-stationary training distribution more like true self-play.

The results here are preliminary but genuinely interesting. The best adaptive configuration — off-policy plus oracle loss — reaches 38.3% benchmark average, compared to 35.4% under the fixed challenger. This is a meaningful improvement, and it suggests that harder, more dynamically chosen questions can push the model further. However, the optimal configuration under adaptive training differs from canonical GATES: adding the oracle loss (which hurts in fixed-challenger mode) seems to help when questions are generated adaptively, possibly because dynamically generated questions are harder and produce more cases of confident but incorrect tutor consensus. The oracle loss provides a corrective signal precisely for those cases.

All adaptive variants comfortably outperform SPICE, the main adaptive-challenger baseline, at matched compute budgets. But the interaction between curriculum difficulty and loss composition is complex enough that the authors wisely present these as preliminary findings rather than definitive claims.


What GATES Means for the Field

The contribution GATES makes is cleaner than most papers manage. It does not introduce a novel architecture, a new loss function family, or a theoretical framework requiring several pages of proofs. It introduces a precise diagnosis of why self-training fails — the inability to distinguish reliable from unreliable self-generated supervision — and a surgical fix that addresses exactly that diagnosis.

The asymmetric-context setup is broadly applicable. Document-grounded question answering is the testbed, but the underlying pattern — privileged information available at training time that will be absent at test time — appears constantly in real applications. Retrieval-augmented generation systems retrieve documents at inference time, but the quality of that retrieval varies; a model trained on cases where retrieval is good could potentially distill that reasoning into a version that handles poor retrieval more gracefully. Tool-use agents have access to tool outputs during training that may be unavailable in production. The GATES framework, at its core, is a recipe for exploiting any training-time information advantage to improve test-time performance — and that is a recipe with many applications.

The paper is also unusually honest about its limitations. Consensus as a correctness proxy fails when the model is so miscalibrated that multiple rollouts converge on the same wrong answer. Discarding low-consensus questions reduces the effective training set and may limit sample efficiency. Answer extraction must work reliably, which may not hold for all domains. And the method was evaluated only in document-grounded math, where the asymmetric context gap is both large and natural. Whether the same dynamics hold in domains where documents provide weaker signal, or where the tutor is not substantially more reliable than the student, remains an open question.

None of this diminishes the main finding. In a setting where virtually every prior approach either fails to improve or actively hurts, GATES produces a 16-percentage-point gain in student accuracy using 551 training examples, a 4-billion-parameter base model, and no external supervision whatsoever. The secret ingredient is knowing when to stop and not learn anything at all — and that turns out to be more powerful than any of the alternatives.


Complete Implementation (PyTorch)

The following is a runnable PyTorch implementation of the full GATES framework — including the asymmetric-context dataset pipeline, the two-level consensus gate (question-level and rollout-level), off-policy trajectory distillation loss (Eq. 1), on-policy advantage-weighted loss (Eq. 3), KL regularization, document-leakage filtering, and the full evaluation suite (maj@8, pass@8, greedy accuracy). Architecture matches the paper: Qwen3-4B-Base equivalent config, k=8 tutor/student rollouts, ≥4/8 consensus threshold, λ_off=1.0, λ_on=0.1, λ_KL=0.02, AdamW optimizer.

# ─────────────────────────────────────────────────────────────────────────────
# GATES: Gated Asymmetric Trajectory Self-Distillation
# Stein, Huang, Goldstein · University of Maryland, College Park
# arXiv:2602.20574v1 [cs.LG] · February 2026
#
# Full implementation: consensus gate, off-policy + on-policy distillation,
# KL regularization, document-leakage filtering, evaluation suite
# ─────────────────────────────────────────────────────────────────────────────

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import re
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Dict
from collections import Counter


# ─── Section 1: Configuration ─────────────────────────────────────────────────

@dataclass
class GATESConfig:
    """
    Hyperparameters matching the GATES paper (Appendix A.3).
    """
    # Model
    model_name: str = "Qwen/Qwen3-4B-Base"
    embed_dim: int = 2048          # hidden dim (simplified for demo)

    # Training
    batch_size: int = 32            # questions per gradient step
    k_tutor: int = 8               # tutor rollouts per question
    k_student: int = 8             # student rollouts per question
    train_temp: float = 0.5        # sampling temperature for rollouts
    max_completion_tokens: int = 512
    num_epochs: int = 1            # single epoch over training split

    # Consensus gate
    consensus_min: int = 4         # minimum agreeing tutor answers (4/8)
    consensus_k: int = 8

    # Advantage clipping
    advantage_clip: float = 5.0    # a in clip(log πT - log πS, [-a, a])

    # Loss weights
    lambda_off: float = 1.0
    lambda_on: float = 0.1
    lambda_cons: float = 0.0       # disabled by default (§4.3)
    lambda_kl: float = 0.02

    # Optimizer
    lr: float = 1e-6
    weight_decay: float = 0.01
    grad_clip: float = 1.0

    # Evaluation
    eval_temp: float = 0.6         # maj@8 sampling temperature
    eval_samples: int = 8
    eval_max_tokens: int = 1024

    # Document leakage keywords (Appendix A.4)
    leakage_keywords: List[str] = field(default_factory=lambda: [
        "document", "passage", "text", "according to",
        "the article", "as stated", "mentioned in"
    ])

    random_seed: int = 42


# ─── Section 2: Data Structures ───────────────────────────────────────────────

@dataclass
class TrainingExample:
    """
    A document–question pair used during GATES training.
    No verified answer is stored; supervision is derived entirely from tutor consensus.
    """
    document: str
    question: str
    example_id: str = ""


@dataclass
class Rollout:
    """A single sampled completion from tutor or student."""
    text: str                          # full completion text
    token_ids: List[int]              # token ids of completion
    log_probs: torch.Tensor            # per-token log-probs under sampling model
    extracted_answer: Optional[str]   # last \boxed{...} expression or None
    has_boxed: bool = False           # whether a parsable boxed answer exists


@dataclass
class ConsensusResult:
    """Output of the two-level consensus gate for one training question."""
    question_gate: bool              # g_i: True if question passes consensus
    consensus_answer: Optional[str]  # a* — modal answer, or None
    eligible_tutor_idx: List[int]   # indices of rollouts with e_{i,j}=1
    agreement_count: int


# ─── Section 3: Answer Extraction ────────────────────────────────────────────

def extract_boxed_answer(text: str) -> Optional[str]:
    """
    Extract the LAST \\boxed{...} expression from a completion.
    Paper uses math-verify library; we provide a regex approximation.
    """
    pattern = r'\\boxed\{([^}]*)\}'
    matches = re.findall(pattern, text)
    return matches[-1].strip() if matches else None


def answers_equivalent(a: str, b: str) -> bool:
    """
    Check symbolic equivalence of two math answers.
    In production: use the math-verify library for robust equivalence.
    Here: normalized string comparison as a simple approximation.
    """
    def normalize(s):
        s = s.strip().lower()
        s = re.sub(r'\s+', '', s)     # remove whitespace
        s = re.sub(r'[,\.]', '', s)    # remove punctuation
        return s
    return normalize(a) == normalize(b)


# ─── Section 4: Document Leakage Guardrails ───────────────────────────────────

def check_document_leakage(text: str, keywords: List[str]) -> bool:
    """
    Returns True if the text appears to explicitly reference a source document.
    Appendix A.4: we filter questions containing these references from distillation.
    """
    text_lower = text.lower()
    return any(kw in text_lower for kw in keywords)


# ─── Section 5: Two-Level Consensus Gate ─────────────────────────────────────

def consensus_gate(
        rollouts: List[Rollout],
        cfg: GATESConfig,
        leakage_check_enabled: bool = True
) -> ConsensusResult:
    """
    Two-level consensus gate (Section 3.2, Figure 2).

    Gate 1 (question-level): g_i = 1 if ≥ τ/k rollouts agree on same answer.
    Gate 2 (rollout-level): e_{i,j} = 1 if rollout j matches consensus answer a*
                             AND passes document-leakage guardrails.

    Returns ConsensusResult with question_gate, consensus_answer, eligible indices.
    """
    # Extract answers from all rollouts
    answers = []
    for r in rollouts:
        ans = r.extracted_answer
        answers.append(ans if (ans is not None and r.has_boxed) else None)

    # ── Gate 1: find modal answer ──────────────────────────────────────────────
    valid_answers = [a for a in answers if a is not None]
    if not valid_answers:
        return ConsensusResult(
            question_gate=False, consensus_answer=None,
            eligible_tutor_idx=[], agreement_count=0)

    # Cluster equivalent answers and count
    clusters: Dict[str, int] = {}
    for ans in valid_answers:
        placed = False
        for rep in list(clusters.keys()):
            if answers_equivalent(ans, rep):
                clusters[rep] += 1
                placed = True
                break
        if not placed:
            clusters[ans] = 1

    modal_key = max(clusters, key=lambda k: clusters[k])
    agreement_count = clusters[modal_key]
    question_gate = agreement_count >= cfg.consensus_min
    consensus_answer = modal_key if question_gate else None

    if not question_gate:
        return ConsensusResult(
            question_gate=False, consensus_answer=None,
            eligible_tutor_idx=[], agreement_count=agreement_count)

    # ── Gate 2: rollout-level eligibility ─────────────────────────────────────
    eligible = []
    for j, (rollout, ans) in enumerate(zip(rollouts, answers)):
        if ans is None:
            continue
        if not answers_equivalent(ans, consensus_answer):
            continue
        if leakage_check_enabled and check_document_leakage(
                rollout.text, cfg.leakage_keywords):
            continue   # exclude document-dependent trajectories
        eligible.append(j)

    return ConsensusResult(
        # The gate can still close here if every consensus-matching rollout
        # was removed by the leakage filter.
        question_gate=len(eligible) > 0,
        consensus_answer=consensus_answer,
        eligible_tutor_idx=eligible,
        agreement_count=agreement_count)
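# The question-level gate above reduces to a modal-vote threshold. A minimal
# self-contained sketch (illustrative only — `passes_gate` is a hypothetical
# helper, not part of GATES, and skips the answer-equivalence clustering):

```python
from collections import Counter

def passes_gate(answers, tau):
    """Return (gate, modal_answer): gate is True iff the most common
    non-None answer appears at least tau times."""
    counts = Counter(a for a in answers if a is not None)
    if not counts:
        return False, None
    modal, n = counts.most_common(1)[0]
    return (True, modal) if n >= tau else (False, None)

# 6 of 8 rollouts agree on "7": the gate passes with tau=5
print(passes_gate(["7"] * 6 + ["21"] * 2, tau=5))  # (True, '7')
# only 4 of 8 agree: the gate closes and the question contributes zero loss
print(passes_gate(["7"] * 4 + ["21"] * 4, tau=5))  # (False, None)
```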


# ─── Section 6: Off-Policy Distillation Loss ─────────────────────────────────

def off_policy_loss(
        model: nn.Module,
        eligible_tutor_rollouts: List[Rollout],
        question_tokens: torch.Tensor,          # (seq_len,) prompt token ids
        device: torch.device
) -> Optional[torch.Tensor]:
    """
    Off-policy NLL distillation loss — Formula (1).

    For each eligible tutor trajectory, compute the student's log-probability
    of predicting each tutor token conditioned on the question prompt (no doc).
    Loss is the negative average per-token log-prob across all eligible rollouts.

    Returns None when there are no eligible rollouts, avoiding a zero denominator.
    """
    if not eligible_tutor_rollouts:
        return None

    total_log_prob = torch.zeros((), device=device)  # scalar accumulator
    total_tokens = 0

    for rollout in eligible_tutor_rollouts:
        tutor_tokens = torch.tensor(rollout.token_ids, device=device)
        L = len(tutor_tokens)
        if L == 0:
            continue

        # Build full sequence: [question_tokens | tutor_tokens]
        full_seq = torch.cat([question_tokens, tutor_tokens]).unsqueeze(0)
        prompt_len = len(question_tokens)

        # Forward pass through student (no document in context)
        with torch.enable_grad():
            logits = model(full_seq).logits  # (1, total_len, vocab_size)

        # Completion region: [prompt_len-1 : prompt_len+L-1]
        completion_logits = logits[0, prompt_len - 1: prompt_len + L - 1]
        student_log_probs = F.log_softmax(completion_logits, dim=-1)

        # Per-token log-prob terms: ℓ_{i,j,t}(θ) = log π_θ(y^T_t | y^T_{<t}, q)
        token_log_probs = student_log_probs[
            torch.arange(L, device=device), tutor_tokens]

        total_log_prob = total_log_prob + token_log_probs.sum()
        total_tokens += L

    if total_tokens == 0:
        return None

    return -total_log_prob / total_tokens   # Formula (1)
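# Formula (1) in miniature on plain floats (hypothetical numbers, no model
# involved): token log-probs are pooled across all eligible rollouts and
# normalized once by the total token count.

```python
def off_policy_nll(per_rollout_log_probs):
    """Negative mean per-token log-prob over all eligible rollouts."""
    total = sum(sum(lps) for lps in per_rollout_log_probs)
    n_tok = sum(len(lps) for lps in per_rollout_log_probs)
    return None if n_tok == 0 else -total / n_tok

# two eligible rollouts with 3 and 2 tokens: −(−1.5) / 5 = 0.3
print(round(off_policy_nll([[-0.1, -0.2, -0.3], [-0.4, -0.5]]), 6))  # 0.3
```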


# ─── Section 7: On-Policy Distillation Loss ───────────────────────────────────

def compute_per_token_advantage(
        tutor_model: nn.Module,
        student_rollout: Rollout,
        question_tokens: torch.Tensor,
        document_tokens: torch.Tensor,
        cfg: GATESConfig,
        device: torch.device
) -> torch.Tensor:
    """
    Compute per-token advantage A_t — Formula (2).

    A_t = clip(log π_T(y^S_t | y^S_{
    student_tokens = torch.tensor(student_rollout.token_ids, device=device)
    L = len(student_tokens)
    if L == 0:
        return torch.zeros(0, device=device)

    # Student log-probs (no document): log π_θ(y^S_t | y^S_{<t}, q)
    q_seq = torch.cat([question_tokens, student_tokens]).unsqueeze(0)
    plen_q = len(question_tokens)
    with torch.no_grad():
        logits_s = tutor_model(q_seq).logits[0, plen_q - 1: plen_q + L - 1]
    log_pi_s = F.log_softmax(logits_s, dim=-1)[
        torch.arange(L, device=device), student_tokens]

    # Tutor log-probs (with document): log π_T(y^S_t | y^S_{<t}, d, q)
    dq_seq = torch.cat([document_tokens, question_tokens, student_tokens]).unsqueeze(0)
    plen_dq = len(document_tokens) + len(question_tokens)
    with torch.no_grad():
        logits_t = tutor_model(dq_seq).logits[0, plen_dq - 1: plen_dq + L - 1]
    log_pi_t = F.log_softmax(logits_t, dim=-1)[
        torch.arange(L, device=device), student_tokens]

    # Advantage: clip(log π_T − log π_S, [−a, a]) — gradient does not flow through
    advantage = (log_pi_t - log_pi_s).clamp(-cfg.advantage_clip, cfg.advantage_clip)
    return advantage.detach()
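# Formula (2) in miniature with plain floats (hypothetical values): the
# advantage is the tutor-minus-student log-prob gap per token, clipped to
# ±a; in training it is detached so no gradient flows through it.

```python
def clipped_advantage(log_pi_tutor, log_pi_student, a):
    """Per-token clip(log π_T − log π_S, [−a, a])."""
    return [max(-a, min(a, t - s))
            for t, s in zip(log_pi_tutor, log_pi_student)]

# gaps of +1.0, −1.8, +0.25 with a=1.0 clip to [1.0, −1.0, 0.25]
print(clipped_advantage([-0.1, -2.0, -0.5], [-1.1, -0.2, -0.75], a=1.0))
```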


def on_policy_loss(
        model: nn.Module,
        student_rollouts: List[Rollout],
        advantages: List[torch.Tensor],
        question_tokens: torch.Tensor,
        gate_passed: bool,
        device: torch.device
) -> Optional[torch.Tensor]:
    """
    On-policy distillation loss — Formula (3).

    ℒ_on(θ) = −(1 / Σ_{i,j} g_i · L^S_{i,j}) · Σ_{i,j,t} g_i · A_{i,j,t} · ℓ^S_{i,j,t}(θ)

    Only applied when question passes the consensus gate (gate_passed=True).
    """
    if not gate_passed or not student_rollouts:
        return None

    total_loss = torch.zeros((), device=device)  # scalar accumulator
    total_tokens = 0
    plen = len(question_tokens)

    for rollout, adv in zip(student_rollouts, advantages):
        student_tokens = torch.tensor(rollout.token_ids, device=device)
        L = len(student_tokens)
        if L == 0 or len(adv) != L:
            continue

        full_seq = torch.cat([question_tokens, student_tokens]).unsqueeze(0)
        with torch.enable_grad():
            logits = model(full_seq).logits[0, plen - 1: plen + L - 1]

        log_pi = F.log_softmax(logits, dim=-1)[
            torch.arange(L, device=device), student_tokens]

        # Advantage-weighted NLL: Σ_t A_{i,j,t} · ℓ^S_{i,j,t}(θ)
        weighted = (adv.to(device) * log_pi).sum()
        total_loss = total_loss + weighted
        total_tokens += L

    if total_tokens == 0:
        return None

    return -total_loss / total_tokens   # Formula (3)
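# Formula (3) in miniature on plain floats (hypothetical numbers): each
# student token's log-prob is weighted by its frozen advantage, summed,
# and normalized; the sign flip turns it into a loss to minimize.

```python
def on_policy_weighted_nll(rollouts):
    """rollouts: list of (advantages, log_probs) pairs of equal length."""
    total = sum(a * lp
                for adv, lps in rollouts
                for a, lp in zip(adv, lps))
    n_tok = sum(len(lps) for _, lps in rollouts)
    return None if n_tok == 0 else -total / n_tok

# one rollout, 2 tokens: −(1.0·(−0.5) + (−1.0)·(−0.25)) / 2 = 0.125
print(on_policy_weighted_nll([([1.0, -1.0], [-0.5, -0.25])]))  # 0.125
```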


# ─── Section 8: KL Regularization ────────────────────────────────────────────

def kl_regularization_loss(
        model: nn.Module,
        ref_model: nn.Module,
        student_rollouts: List[Rollout],
        question_tokens: torch.Tensor,
        gate_passed: bool,
        device: torch.device
) -> Optional[torch.Tensor]:
    """
    KL regularization toward frozen reference policy — Formula (6/7).
    Prevents catastrophic drift from pretrained base during self-distillation.

    ℒ_KL(θ) ≈ E[KL(π_θ(·|q) ‖ π_ref(·|q))] — averaged over student rollouts;
    the weight λ_KL is applied in the composite loss (Formula 7).
    """
    if not gate_passed or not student_rollouts:
        return None

    total_kl = torch.zeros((), device=device)  # scalar accumulator
    total_tokens = 0
    plen = len(question_tokens)

    for rollout in student_rollouts:
        student_tokens = torch.tensor(rollout.token_ids, device=device)
        L = len(student_tokens)
        if L == 0:
            continue

        full_seq = torch.cat([question_tokens, student_tokens]).unsqueeze(0)
        with torch.enable_grad():
            logits_model = model(full_seq).logits[0, plen - 1: plen + L - 1]
        with torch.no_grad():
            logits_ref = ref_model(full_seq).logits[0, plen - 1: plen + L - 1]

        log_pi = F.log_softmax(logits_model, dim=-1)
        log_ref = F.log_softmax(logits_ref, dim=-1)

        # Token-level KL(π_θ ‖ π_ref): F.kl_div(input, target) computes
        # Σ target·(log target − input) = Σ_v π_θ(v)(log π_θ(v) − log π_ref(v))
        kl = F.kl_div(log_ref, log_pi.exp(), reduction='sum')
        total_kl = total_kl + kl
        total_tokens += L

    if total_tokens == 0:
        return None

    return total_kl / total_tokens
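# The F.kl_div call above, with input = log π_ref and target = π_θ, computes
# Σ_v π_θ(v)(log π_θ(v) − log π_ref(v)) = KL(π_θ ‖ π_ref) per position. A
# plain-float sketch over a hypothetical 3-token vocabulary:

```python
import math

def kl_divergence(p, q):
    """KL(p || q) = Σ_v p(v) · (log p(v) − log q(v)), skipping p(v) = 0."""
    return sum(pv * (math.log(pv) - math.log(qv))
               for pv, qv in zip(p, q) if pv > 0)

# identical distributions have zero divergence
print(kl_divergence([0.5, 0.25, 0.25], [0.5, 0.25, 0.25]))  # 0.0
# a peaked policy diverging from a uniform reference
print(round(kl_divergence([0.9, 0.05, 0.05], [1/3, 1/3, 1/3]), 4))
```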


# ─── Section 9: GATES Trainer ─────────────────────────────────────────────────

class GATESTrainer:
    """
    Full GATES training loop combining:
        - Tutor sampling under document context
        - Two-level consensus gating
        - Off-policy trajectory distillation (primary signal)
        - On-policy advantage-weighted distillation (secondary signal)
        - KL regularization toward frozen reference
        - Document-leakage filtering
        - Joint gradient update over all active losses

    Implements Algorithm 1 (training procedure) from the GATES paper.
    """

    def __init__(self, model: nn.Module, ref_model: nn.Module,
                 cfg: GATESConfig, device: torch.device):
        self.model = model
        self.ref_model = ref_model
        self.ref_model.eval()
        for p in self.ref_model.parameters():
            p.requires_grad = False

        self.cfg = cfg
        self.device = device
        self.optimizer = optim.AdamW(
            model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    def _sample_rollout(self, prompt_tokens: torch.Tensor,
                        max_new: int, greedy: bool = False) -> Rollout:
        """
        Sample one completion from the model.
        In production: use vLLM or HF generate() with proper sampling parameters.
        This stub demonstrates the interface.
        """
        with torch.no_grad():
            generated = self.model.generate(
                prompt_tokens.unsqueeze(0).to(self.device),
                max_new_tokens=max_new,
                temperature=0.0 if greedy else self.cfg.train_temp,
                do_sample=not greedy,
                pad_token_id=0
            )
        completion_ids = generated[0, len(prompt_tokens):].tolist()
        text = ""  # would decode here in production
        extracted = extract_boxed_answer(text)
        return Rollout(
            text=text,
            token_ids=completion_ids,
            log_probs=torch.zeros(len(completion_ids)),
            extracted_answer=extracted,
            has_boxed=extracted is not None
        )

    def train_step(self, examples: List[TrainingExample],
                   tokenizer) -> dict:
        """
        One gradient step over a batch of document–question pairs.

        For each example:
            1. Sample k tutor rollouts with document in context
            2. Apply consensus gate
            3. If gate passes: sample k student rollouts, compute all losses
            4. Accumulate gradients; step optimizer after full batch
        """
        self.model.train()
        batch_losses = []
        batch_stats = {'gated': 0, 'skipped': 0, 'total': len(examples)}

        self.optimizer.zero_grad()

        for ex in examples:
            # ── Tokenize prompts (prompt tokens only; loss on completions) ─────
            tutor_prompt = (
                f"Document:\n{ex.document}\n\nQuestion:\n{ex.question}\n"
                "Solve step by step. Put FINAL answer in \\boxed{}.\n"
                "Do NOT mention the document. Solution:\n"
            )
            student_prompt = (
                f"Question:\n{ex.question}\n"
                "Solve step by step. Put FINAL answer in \\boxed{}.\nSolution:\n"
            )
            doc_q_tokens = torch.tensor(
                tokenizer.encode(tutor_prompt), device=self.device)
            q_tokens = torch.tensor(
                tokenizer.encode(student_prompt), device=self.device)
            # Tokenize the document prefix directly; slicing doc_q_tokens
            # would be wrong because the two prompts also differ in their
            # instructions, not just the document.
            doc_tokens = torch.tensor(
                tokenizer.encode(f"Document:\n{ex.document}\n\n"),
                device=self.device)

            # ── Step 1: Sample k tutor rollouts ───────────────────────────────
            tutor_rollouts = [
                self._sample_rollout(doc_q_tokens, self.cfg.max_completion_tokens)
                for _ in range(self.cfg.k_tutor)
            ]

            # ── Step 2: Consensus gate ─────────────────────────────────────────
            gate_result = consensus_gate(tutor_rollouts, self.cfg)

            if not gate_result.question_gate:
                batch_stats['skipped'] += 1
                continue   # zero loss for this question

            batch_stats['gated'] += 1
            eligible_rollouts = [tutor_rollouts[j]
                                 for j in gate_result.eligible_tutor_idx]

            # ── Step 3a: Off-policy distillation loss ─────────────────────────
            l_off = off_policy_loss(
                self.model, eligible_rollouts, q_tokens, self.device)

            # ── Step 3b: Sample k student rollouts ────────────────────────────
            student_rollouts = [
                self._sample_rollout(q_tokens, self.cfg.max_completion_tokens)
                for _ in range(self.cfg.k_student)
            ]

            # ── Step 3c: Per-token advantages for on-policy loss ──────────────
            advantages = [
                compute_per_token_advantage(
                    self.model, s, q_tokens, doc_tokens, self.cfg, self.device)
                for s in student_rollouts
            ]

            # ── Step 3d: On-policy distillation loss ──────────────────────────
            l_on = on_policy_loss(
                self.model, student_rollouts, advantages,
                q_tokens, gate_result.question_gate, self.device)

            # ── Step 3e: KL regularization ────────────────────────────────────
            l_kl = kl_regularization_loss(
                self.model, self.ref_model, student_rollouts,
                q_tokens, gate_result.question_gate, self.device)

            # ── Step 3f: Composite loss — Formula (7) ─────────────────────────
            # Start from a plain zero (no requires_grad) so loss.requires_grad
            # is True only when at least one active loss term was added.
            loss = torch.zeros((), device=self.device)
            if l_off is not None:
                loss = loss + self.cfg.lambda_off * l_off
            if l_on is not None:
                loss = loss + self.cfg.lambda_on * l_on
            if l_kl is not None:
                loss = loss + self.cfg.lambda_kl * l_kl

            if loss.requires_grad:
                loss.backward()
                batch_losses.append(loss.item())

        # ── Gradient update after full batch ──────────────────────────────────
        if batch_losses:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
            self.optimizer.step()

        return {
            'mean_loss': sum(batch_losses) / max(len(batch_losses), 1),
            'gated_fraction': batch_stats['gated'] / max(batch_stats['total'], 1),
            **batch_stats
        }
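# The composite objective (Formula 7) in miniature: inactive terms return
# None and simply drop out of the sum. Weights here are hypothetical plain
# floats, standing in for λ_off, λ_on, and λ_KL.

```python
def composite_loss(terms):
    """terms: list of (weight, loss_or_None); None terms are skipped."""
    active = [(w, l) for w, l in terms if l is not None]
    return None if not active else sum(w * l for w, l in active)

# off-policy and KL terms active, on-policy inactive: 1.0·0.5 + 0.25·0.5
print(composite_loss([(1.0, 0.5), (0.5, None), (0.25, 0.5)]))  # 0.625
# every term inactive → no gradient step for this question
print(composite_loss([(1.0, None), (0.5, None)]))  # None
```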

    def train(self, train_examples: List[TrainingExample],
              tokenizer, log_every: int = 10):
        """Full training loop — one epoch over the training split."""
        random.shuffle(train_examples)
        steps = range(0, len(train_examples), self.cfg.batch_size)
        for step_idx, start in enumerate(steps):
            batch = train_examples[start: start + self.cfg.batch_size]
            stats = self.train_step(batch, tokenizer)
            if step_idx % log_every == 0:
                print(
                    f"Step {step_idx:4d} | Loss: {stats['mean_loss']:6.4f} | "
                    f"Gated: {stats['gated_fraction']:.2%} of questions"
                )


# ─── Section 10: Evaluation Suite ────────────────────────────────────────────

def evaluate_student(
        model: nn.Module,
        examples: List[TrainingExample],
        oracle_answers: List[str],
        tokenizer,
        cfg: GATESConfig,
        mode: str = "maj8"
) -> dict:
    """
    Evaluate document-free student accuracy under three decoding strategies:
        - greedy:  single sample, temperature=0
        - maj8:    majority voting over 8 samples (paper primary metric)
        - pass8:   pass@8, True if any of 8 samples is correct

    oracle_answers: ground-truth answers used ONLY for evaluation (not training).
    """
    model.eval()
    correct_greedy = 0
    correct_maj = 0
    correct_pass = 0
    total = len(examples)

    with torch.no_grad():
        for ex, oracle in zip(examples, oracle_answers):
            student_prompt = (
                f"Question:\n{ex.question}\n"
                "Solve step by step. Put FINAL answer in \\boxed{}.\nSolution:\n"
            )
            device = next(model.parameters()).device
            q_tokens = torch.tensor(
                tokenizer.encode(student_prompt), device=device)

            # ── Greedy sample ─────────────────────────────────────────────────
            greedy_out = model.generate(
                q_tokens.unsqueeze(0),
                max_new_tokens=cfg.eval_max_tokens,
                temperature=0.0, do_sample=False, pad_token_id=0)
            greedy_text = tokenizer.decode(greedy_out[0, len(q_tokens):])
            greedy_ans = extract_boxed_answer(greedy_text)
            if greedy_ans and answers_equivalent(greedy_ans, oracle):
                correct_greedy += 1

            # ── 8 sampled completions for maj@8 and pass@8 ────────────────────
            sampled_answers = []
            for _ in range(cfg.eval_samples):
                out = model.generate(
                    q_tokens.unsqueeze(0),
                    max_new_tokens=cfg.eval_max_tokens,
                    temperature=cfg.eval_temp, do_sample=True, pad_token_id=0)
                text = tokenizer.decode(out[0, len(q_tokens):])
                ans = extract_boxed_answer(text)
                if ans:
                    sampled_answers.append(ans)

            # maj@8: majority vote
            if sampled_answers:
                # cluster answers by equivalence
                clusters: Dict[str, int] = {}
                for a in sampled_answers:
                    placed = False
                    for rep in clusters:
                        if answers_equivalent(a, rep):
                            clusters[rep] += 1
                            placed = True
                            break
                    if not placed:
                        clusters[a] = 1
                majority = max(clusters, key=clusters.__getitem__)
                if answers_equivalent(majority, oracle):
                    correct_maj += 1

                # pass@8: any correct
                if any(answers_equivalent(a, oracle) for a in sampled_answers):
                    correct_pass += 1

    return {
        'greedy_acc': correct_greedy / max(total, 1),
        'maj8_acc': correct_maj / max(total, 1),
        'pass8_acc': correct_pass / max(total, 1),
        'total': total,
    }
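# The three decoding metrics in miniature on plain strings (hypothetical
# samples; the real evaluator clusters answers with answers_equivalent
# before voting):

```python
from collections import Counter

def maj_and_pass(samples, oracle):
    """Return (maj@k correct, pass@k correct) for one question."""
    if not samples:
        return False, False
    majority = Counter(samples).most_common(1)[0][0]
    return majority == oracle, oracle in samples

print(maj_and_pass(["7", "7", "21", "7"], "7"))  # (True, True)
print(maj_and_pass(["21", "21", "7"], "7"))      # (False, True)
print(maj_and_pass(["21", "42"], "7"))           # (False, False)
```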


# ─── Section 11: Main Entry Point ────────────────────────────────────────────

if __name__ == "__main__":
    print("=" * 72)
    print("GATES: Gated Asymmetric Trajectory Self-Distillation")
    print("Stein, Huang, Goldstein · University of Maryland · arXiv:2602.20574v1")
    print("=" * 72)

    cfg = GATESConfig()
    torch.manual_seed(cfg.random_seed)
    random.seed(cfg.random_seed)

    print("\n[1] Configuration:")
    print(f"    Tutor rollouts:    k={cfg.k_tutor}")
    print(f"    Consensus min:     {cfg.consensus_min}/{cfg.consensus_k}")
    print(f"    Loss weights:      λ_off={cfg.lambda_off}, λ_on={cfg.lambda_on}, λ_KL={cfg.lambda_kl}")
    print(f"    Oracle reward:     λ_cons={cfg.lambda_cons} (disabled by default)")
    print(f"    Advantage clip:    ±{cfg.advantage_clip}")

    print("\n[2] Demonstrating consensus gate logic...")
    # Simulate 8 tutor rollouts: 6 agree on "7", 2 on "21"
    simulated_rollouts = []
    for i in range(6):
        simulated_rollouts.append(Rollout(
            text=f"42 = 2 × 3 × 7; largest prime is 7. \\boxed{{7}}",
            token_ids=[100, 101, 102],
            log_probs=torch.zeros(3),
            extracted_answer="7", has_boxed=True
        ))
    for i in range(2):
        simulated_rollouts.append(Rollout(
            text=f"42 / 2 = 21; largest factor is 21. \\boxed{{21}}",
            token_ids=[200, 201],
            log_probs=torch.zeros(2),
            extracted_answer="21", has_boxed=True
        ))

    result = consensus_gate(simulated_rollouts, cfg)
    print(f"    Question gate:     {result.question_gate}")
    print(f"    Consensus answer:  {result.consensus_answer}")
    print(f"    Agreement count:   {result.agreement_count}/{cfg.consensus_k}")
    print(f"    Eligible rollouts: {len(result.eligible_tutor_idx)} of {cfg.k_tutor}")

    print("\n[3] Demonstrating answer extraction...")
    test_completions = [
        "Step 1: factor 42 = 2 × 3 × 7. The primes are 2, 3, 7. \\boxed{7}",
        "Computing... final answer is \\boxed{42} no wait \\boxed{7}",
        "I'm not sure. Perhaps \\boxed{21}? No, \\boxed{7}.",
        "The answer is probably 7."  # no boxed
    ]
    for c in test_completions:
        ans = extract_boxed_answer(c)
        print(f"    '{c[:50]}...' → {ans}")

    print("\n[4] Demonstrating leakage filter...")
    clean = "42 = 2 × 3 × 7, so the largest prime factor is 7."
    leaky = "According to the document, 42 = 2 × 3 × 7, so the answer is 7."
    print(f"    Clean trajectory leak detected: {check_document_leakage(clean, cfg.leakage_keywords)}")
    print(f"    Leaky trajectory leak detected: {check_document_leakage(leaky, cfg.leakage_keywords)}")

    print("\n[5] Demonstrating answer equivalence...")
    pairs = [("7", "7"), ("7", " 7 "), ("7", "21"), ("1/2", "1/2")]
    for a, b in pairs:
        print(f"    '{a}' == '{b}': {answers_equivalent(a, b)}")

    print("\n" + "=" * 72)
    print("GATES implementation components:")
    print("  GATESConfig              — full hyperparameter config (matches Appendix A.3)")
    print("  TrainingExample/Rollout  — data structures for doc-question pairs")
    print("  extract_boxed_answer     — last \\boxed{} extraction from completions")
    print("  answers_equivalent       — symbolic answer comparison")
    print("  check_document_leakage   — guardrail for doc-dependent trajectories")
    print("  consensus_gate           — 2-level gate: question-level + rollout-level")
    print("  off_policy_loss          — NLL on eligible tutor trajectories (Formula 1)")
    print("  compute_per_token_advantage — clipped log πT − log πS (Formula 2)")
    print("  on_policy_loss           — advantage-weighted student NLL (Formula 3)")
    print("  kl_regularization_loss   — KL to frozen reference (Formula 6)")
    print("  GATESTrainer.train_step  — full per-batch training step (Algorithm 1)")
    print("  evaluate_student         — greedy / maj@8 / pass@8 accuracy")
    print("=" * 72)

Access the Paper and Resources

GATES was authored by Alex Stein, Furong Huang, and Tom Goldstein at the University of Maryland, College Park. The work was supported by DARPA TIAMAT, the NSF TRAILS Institute (2229885), Coefficient Giving, and Longview Philanthropy.

Academic Citation:
Stein, A., Huang, F., & Goldstein, T. (2026). GATES: Self-Distillation under Privileged Context with Consensus Gating. arXiv preprint arXiv:2602.20574. https://arxiv.org/abs/2602.20574

This article is an independent editorial analysis of research posted to arXiv. The views and commentary expressed here reflect the editorial perspective of this site and do not represent the views of the original authors or their institution. Diagrams are original illustrations created for this article and do not reproduce figures from the paper.
