DAIT: Why You Should Never Ask CLIP to Directly Teach ResNet-18 — And What to Do Instead
Researchers at Nanjing Normal University placed a trainable 10-million-parameter intermediary between a frozen 1.2-billion-parameter vision-language model and a tiny student classifier — solving the architectural mismatch, noise contamination, and unstable optimization that make direct VLM-to-lightweight distillation fail, and gaining +12.63% on FGVC-Aircraft and +8.34% on CUB-200-2011 over the previous state-of-the-art across four student architectures.
CLIP knows a remarkable amount about birds. It can describe the difference between a Laysan Albatross and a Black-footed Albatross in natural language, and its visual encoder captures the subtle plumage differences that distinguish them. But ask it to teach those distinctions to a ResNet-18 running on a mobile phone, and something goes wrong. The student’s attention scatters across backgrounds, irrelevant textures, and generic visual patterns — the very things CLIP also knows about but that are useless for telling one bird species from another. DAIT fixes this with a deceptively simple insight: put a smart intermediary in the middle, one small enough to be efficient but trained specifically to filter and reorganize what the VLM knows into something the student can actually learn from.
The Problem With Asking a Generalist to Teach a Specialist
Vision-language models like CLIP are trained on hundreds of millions to billions of image-text pairs scraped from the internet. They learn to associate visual features with linguistic concepts at enormous scale. The result is a model that understands an almost unlimited range of visual concepts — not just birds and planes but logos, artworks, memes, diagrams, and everything in between. That breadth is exactly what makes VLMs powerful and exactly what makes them bad teachers for fine-grained visual categorization.
When you ask a ResNet-18 to mimic CLIP’s features on a dataset of 200 bird species, you are asking an 11-million-parameter network to absorb the representational priorities of a 1.2-billion-parameter model trained on a completely different objective. Three specific problems arise. First, the capacity gap is enormous: the student simply cannot reproduce the VLM’s high-dimensional representations, and the gradients it receives while trying to approximate them are noisy and unstable. Second, CLIP’s features carry massive amounts of task-irrelevant information (the tree behind the bird, the style of the photograph, the color temperature of the lighting) that is actively harmful for learning to distinguish bird species. Third, the architectures are misaligned: even though the VLM used here has a convolutional image encoder (ConvNeXt-XXLarge), its width, depth, and contrastive image-text training organize features very differently from a small ImageNet-style CNN, and asking one to mimic the other’s intermediate features directly is like asking someone to translate poetry by matching the number of syllables.
The result, as the paper’s attention visualizations make clear, is a student that pays attention to the wrong things. Direct VLM distillation produces heatmaps that light up on backgrounds, sky, and general scene context — while the features that actually distinguish a Herring Gull from a Ring-billed Gull (bill markings, leg color, wing tip pattern) are treated as noise.
Direct VLM-to-lightweight distillation fails not because knowledge distillation does not work, but because the knowledge being transferred is the wrong knowledge. CLIP contains enormous amounts of information that is actively counterproductive for fine-grained recognition. The fix is not a better distillation algorithm — it is a better filter that reorganizes what the VLM knows before it reaches the student.
The DAIT Solution: A Two-Stage Hierarchy
DAIT inserts a trainable intermediate teacher between the frozen VLM and the lightweight student. The intermediate teacher has two jobs: first, learn from the VLM while being simultaneously supervised on the target fine-grained task; second, once trained, serve as a frozen, task-aligned teacher for the lightweight student. This creates a two-stage distillation hierarchy — VLM → Intermediate Teacher → Student — where each hop closes a more manageable gap than trying to jump from one end to the other directly.
DAIT FRAMEWORK: TWO-STAGE DISTILLATION HIERARCHY
════════════════════════════════════════════════════════════════
STAGE 1: VLM → INTERMEDIATE TEACHER (teacher is trainable)
Input image x
│
├──→ Data Augmentation A(·) → x̃ (augmented view)
│
├──→ VLM Image Encoder E_V^I(x̃) → condensation proj f_vlm → z_v ∈ ℝ^{B×D}
│ (ConvNeXt-XXLarge, 1.2B params, FROZEN throughout)
│
├──→ VLM Text Encoder E_V^T(prompt(c)) → f_vlm → t_c ∈ ℝ^{N×D}
│ prompt(c) = "A photo of a {classname}"
│
└──→ Intermediate Teacher T_m(x̃) → z̃_t ∈ ℝ^{B×D}
(RegNet-Y-1.6GF, 10.4M params, TRAINABLE)
Three supervision signals for intermediate teacher:
L_SIA (Semantic Image Alignment):
cos(z_v, t_c) = z_v · t_c^T / (‖z_v‖‖t_c‖) ← VLM image-text similarity
cos(z̃_t, t_c) = z̃_t · t_c^T / (‖z̃_t‖‖t_c‖) ← teacher image-text similarity
L_SIA = T² · KL(softmax(cos(z̃_t,t_c)/T) ‖ softmax(cos(z_v,t_c)/T))
→ Aligns teacher's visual-semantic relationships with VLM's
→ Uses text anchors as cross-modal supervision signal
L_IRA (Image Representation Alignment):
L_IRA = (1/BD) Σ_i Σ_j |z̃_t(i,j) − z_v(i,j)|
→ L1 calibration: teacher inherits VLM visual encoding strength
L_CLS (Classification):
L_CLS = −Σ_i y_i log p(y_i | x̃)
→ Task supervision: teacher focuses on discriminative fine-grained cues
Combined with dynamic weighting (λ = ke + b, starts at 0):
L_VLM2Inter = λ·L_CLS + (1−λ)/2 · (L_SIA + L_IRA)
Early epochs: distillation-heavy (learn VLM knowledge)
Late epochs: task-heavy (adapt to fine-grained discrimination)
────────────────────────────────────────────────────────────────
STAGE 2: INTERMEDIATE TEACHER → LIGHTWEIGHT STUDENT
(intermediate teacher is FROZEN; student is trainable)
Original image x (no augmentation in Stage 2)
│
├──→ Intermediate Teacher T_m(x) → z_t ∈ ℝ^{B×D×H×W} (feature maps)
│ (FROZEN — provides task-aligned discriminative supervision)
│
└──→ Lightweight Student LW(x) → conv adapter f_stu → z_s ∈ ℝ^{B×D×H×W}
L_SRA (Spatial Representation Alignment):
L_SRA = (1/HW) Σ_h Σ_w ‖z_s(h,w) − z_t(h,w)‖²₂
→ MSE on final conv feature maps
→ Student learns fine-grained spatial patterns from filtered teacher
L_CLS: same as Stage 1 (ground-truth label supervision)
Combined with same dynamic λ schedule:
L_Inter2Lite = λ·L_CLS + (1−λ)·L_SRA
════════════════════════════════════════════════════════════════
KEY: The intermediate teacher transforms generic VLM features
into compact, task-aligned features before the student ever
sees them. The student learns from already-filtered knowledge.
════════════════════════════════════════════════════════════════
Stage 1: Teaching the Intermediate Teacher
The intermediate teacher, RegNet-Y-1.6GF (a compact convolutional network with just 10.4 million parameters), receives augmented views of the training images and is supervised by three simultaneous loss signals. Each pulls its representations in a slightly different direction; together they produce something more useful than any one signal alone.
The Semantic Image Alignment (SIA) loss is where the text encoder enters the picture. For each class c, CLIP’s text encoder converts the prompt “A photo of a {classname}” into a text embedding t_c. The VLM’s visual embedding z_v and the intermediate teacher’s visual embedding z̃_t are both compared to these text embeddings through cosine similarity. The KL divergence between the two temperature-softened similarity distributions, multiplied by T² so that gradient magnitudes stay comparable as the temperature changes, forces the intermediate teacher to organize its visual representations in alignment with the VLM’s cross-modal understanding of the classes. Crucially, this uses language as an anchor: the teacher must learn not just to look like CLIP visually, but to understand images the same way CLIP understands them in relation to their verbal descriptions.
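A minimal PyTorch sketch of the SIA term, assuming the image and text embeddings have already passed through their projections into the shared D-dimensional space. It follows the argument order written in the framework diagram, KL(p_teacher ‖ p_VLM); whether the paper’s code uses that direction or the reverse one from classic knowledge distillation is not stated here, so treat the direction as an assumption. The function name `sia_loss` and the random stand-in tensors are illustrative.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def sia_loss(z_t: Tensor, z_v: Tensor, t_c: Tensor, T: float = 2.0) -> Tensor:
    """T^2 * KL(softmax(cos(z_t, t_c)/T) || softmax(cos(z_v, t_c)/T)), batch-averaged.

    z_t: (B, D) intermediate-teacher embeddings (trainable side)
    z_v: (B, D) projected VLM image embeddings (target side)
    t_c: (N, D) projected VLM text embeddings, one per class prompt
    """
    t_n = F.normalize(t_c, dim=-1)
    # Cosine similarity of every image to every class prompt: (B, N)
    cos_t = F.normalize(z_t, dim=-1) @ t_n.T
    cos_v = F.normalize(z_v, dim=-1) @ t_n.T
    # Temperature-softened log-distributions over the N classes
    log_p_t = F.log_softmax(cos_t / T, dim=-1)
    log_p_v = F.log_softmax(cos_v / T, dim=-1)
    # KL(p_t || p_v) in the argument order written in the diagram; swap the two
    # log-distributions for the KL(p_vlm || p_teacher) direction of classic KD
    kl = (log_p_t.exp() * (log_p_t - log_p_v)).sum(dim=-1).mean()
    # The T^2 factor keeps gradient magnitudes comparable across temperatures
    return (T ** 2) * kl


torch.manual_seed(0)
z_v, t_c = torch.randn(4, 512), torch.randn(200, 512)  # stand-ins for VLM outputs
z_t = torch.randn(4, 512, requires_grad=True)           # stand-in for the teacher embedding
loss = sia_loss(z_t, z_v, t_c)
loss.backward()  # gradients reach z_t; z_v and t_c carry no grad here
```

Computing the KL by hand from log-probabilities, rather than through `F.kl_div`, keeps the argument order explicit and lets gradients reach the teacher-side distribution, which here sits in the first argument.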
The Image Representation Alignment (IRA) loss is simpler — an L1 distance between the projected VLM visual features z_v and the teacher’s visual features z̃_t. Where SIA ensures semantic structure is preserved, IRA ensures the raw visual representation quality is inherited, calibrating the teacher’s feature space to stay close to the VLM’s regardless of how the task-adaptive training pulls it.
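The IRA term is a plain mean absolute error. A short check, assuming (B, D) embeddings in the shared space, that the (1/BD) double sum from the diagram is exactly what `F.l1_loss` computes with its default `reduction='mean'`; `ira_loss` is an illustrative name.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def ira_loss(z_t: Tensor, z_v: Tensor) -> Tensor:
    """L_IRA = (1/BD) * sum_i sum_j |z_t(i, j) - z_v(i, j)| for (B, D) embeddings."""
    B, D = z_t.shape
    return (z_t - z_v).abs().sum() / (B * D)


torch.manual_seed(0)
z_t, z_v = torch.randn(32, 512), torch.randn(32, 512)
# The explicit double sum agrees with PyTorch's mean-reduced L1 loss
same_as_l1 = torch.allclose(ira_loss(z_t, z_v), F.l1_loss(z_t, z_v))
```

One common reason to prefer L1 over squared error for this kind of calibration term is that its penalty grows linearly, so a few badly mismatched coordinates do not dominate the gradient.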
The Classification loss is standard cross-entropy on the ground-truth fine-grained labels. This is what prevents the intermediate teacher from becoming a passive copy of the VLM — the task loss forces it to develop fine-grained discriminative capacity that the VLM, as a general-purpose model, does not inherently possess.
The Dynamic Loss Schedule — Starting Humble, Growing Confident
The three losses are combined with a linearly increasing weight λ = ke + b, where e is the current epoch and λ is capped at 1. With b = 0, the schedule starts at λ = 0 (distillation losses dominate) and gradually increases λ toward 1 (classification loss dominates). The rationale is compelling: early in training, the intermediate teacher needs to absorb the VLM’s representational structure before it can usefully adapt it to the task. If task supervision dominates too early, the teacher risks developing representations that are good at the training task but have not inherited the VLM’s rich visual knowledge. The dynamic schedule lets knowledge absorption come first, task adaptation second.
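The schedule and the two combined objectives from the diagram can be written as three small functions. The slope values below are hypothetical (the paper’s k is set by its Fig. 7 ablation, which is not reproduced here), and e is assumed to count raw epochs. Under that assumption the second schedule shows a pitfall: a slope of 0.5, the default in the `DAITConfig` listing, pushes λ to its cap after two epochs, so the distillation terms vanish for almost all of a 100-epoch stage.

```python
def lambda_at_epoch(epoch: int, k: float, b: float = 0.0, lam_max: float = 1.0) -> float:
    """Dynamic weight lambda = k*e + b, clamped to [0, lam_max]."""
    return min(max(k * epoch + b, 0.0), lam_max)


def stage1_objective(l_cls, l_sia, l_ira, lam: float):
    """L_VLM2Inter = lam * L_CLS + (1 - lam) / 2 * (L_SIA + L_IRA)."""
    return lam * l_cls + (1.0 - lam) / 2.0 * (l_sia + l_ira)


def stage2_objective(l_cls, l_sra, lam: float):
    """L_Inter2Lite = lam * L_CLS + (1 - lam) * L_SRA."""
    return lam * l_cls + (1.0 - lam) * l_sra


# Hypothetical k = 0.02, b = 0: lambda ramps up from 0 and hits the cap at epoch 50
ramp = [lambda_at_epoch(e, k=0.02) for e in range(100)]
# k = 0.5 saturates lambda after two epochs
fast = [lambda_at_epoch(e, k=0.5) for e in range(4)]  # [0.0, 0.5, 1.0, 1.0]
```

The functions accept plain floats or scalar tensors, so the same code can weight the loss tensors inside a training loop.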
Prior multi-stage distillation methods like TAKD simply insert progressively smaller networks to reduce the capacity gap step by step. DAIT’s intermediate teacher plays a different role: it is a knowledge filter and task adapter that reorganizes VLM representations under fine-grained supervision. The fact that RegNet-Y-1.6GF (10.4M params) outperforms VGG-13 (129M params) as an intermediate teacher suggests that model size is secondary to architectural compatibility and the quality of task-adaptive training.
Stage 2: Teaching the Lightweight Student
With the intermediate teacher frozen, Stage 2 is straightforward feature-level distillation. The final convolutional feature maps of the intermediate teacher z_t and the lightweight student z_s (after a convolutional adapter layer) are aligned via MSE loss — the Spatial Representation Alignment (SRA) loss. Combined with classification supervision and the same dynamic λ schedule, the student learns to produce spatial feature maps that capture the fine-grained discriminative patterns the intermediate teacher has developed.
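A minimal sketch of one Stage-2 update, assuming single-layer convolutions as hypothetical stand-ins for the frozen intermediate teacher and the student trunk; only the shapes (888 teacher channels as for RegNet-Y-1.6GF, 512 student channels as for ResNet-18, a shared 7×7 grid) mirror the real models. The 1×1 convolution used as the adapter f_stu, the fixed λ = 0.3, and the 10-class head are also assumptions. `sra_loss` implements the L_SRA formula from the diagram: squared L2 distance over channels, averaged over spatial positions and the batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def sra_loss(z_s: Tensor, z_t: Tensor) -> Tensor:
    """(1/HW) * sum_{h,w} ||z_s(h,w) - z_t(h,w)||_2^2, averaged over the batch.

    Both inputs are (B, C, H, W); the student's adapter is what makes C match.
    """
    sq_norm = (z_s - z_t).pow(2).sum(dim=1)  # squared L2 over channels: (B, H, W)
    return sq_norm.mean()                     # average over H, W and the batch


torch.manual_seed(0)
# Hypothetical one-layer stand-ins that reproduce the shapes only (56x56 input -> 7x7 maps)
teacher = nn.Conv2d(3, 888, kernel_size=8, stride=8)
student = nn.Conv2d(3, 512, kernel_size=8, stride=8)
adapter = nn.Conv2d(512, 888, kernel_size=1)  # f_stu: 1x1 conv to the teacher's width
head = nn.Linear(888, 10)

teacher.eval()  # Stage 2: the intermediate teacher is frozen
for p in teacher.parameters():
    p.requires_grad_(False)

params = [*student.parameters(), *adapter.parameters(), *head.parameters()]
opt = torch.optim.AdamW(params, lr=1e-4, weight_decay=1e-4)
x, y = torch.randn(4, 3, 56, 56), torch.randint(0, 10, (4,))
lam = 0.3  # value the dynamic schedule would supply at this epoch

with torch.no_grad():
    z_t = teacher(x)                  # (4, 888, 7, 7)
z_s = adapter(student(x))             # (4, 888, 7, 7)
logits = head(z_s.mean(dim=(2, 3)))   # global average pool, then classify
loss = lam * F.cross_entropy(logits, y) + (1 - lam) * sra_loss(z_s, z_t)

opt.zero_grad()
loss.backward()
opt.step()
```

`F.mse_loss(z_s, z_t)` would additionally divide by the channel count C. It has the same minimizer but is C times smaller, which quietly changes the balance between L_SRA and L_CLS under a given λ.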
The key difference from direct VLM distillation is that z_t is no longer the raw, noisy, task-irrelevant output of a billion-parameter general-purpose model. It is the output of a model specifically trained to emphasize the features that distinguish bird species, aircraft models, car variants, and dog breeds. The student is not trying to reproduce CLIP — it is trying to reproduce a compact version of fine-grained visual understanding.
Why RegNet-Y-1.6GF? The Counterintuitive Size Finding
One of the most interesting results in the ablation study is the choice of intermediate teacher architecture. The paper tests VGG-13 (129M params), ResNet-50 (26.7M), ResNet-34 (21.3M), ResNet-18 (11.2M), RegNet-Y-1.6GF (10.4M), and EfficientNet-B0 (4.1M) as candidates. RegNet-Y-1.6GF wins by a clear margin despite having fewer parameters than the student (ResNet-18 has 11.2M parameters vs. RegNet’s 10.4M).
The CKA (Centered Kernel Alignment) analysis suggests why. RegNet achieves the second-highest feature similarity to the VLM despite having only a fraction of VGG-13’s parameters. VGG-13 actually has the highest raw CKA score, yet it produces the worst student performance. The paper suspects this is a dimensionality artifact: VGG-13’s 4096-dimensional features may yield artificially inflated CKA values, while its actual representational compatibility with the VLM is poor. RegNet’s much more compact 888-dimensional features still reach a high CKA score, and together with its stage-wise width design, which matches the hierarchical feature structure of modern VLMs, this makes it the strongest intermediary in the ablation.
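Linear CKA, in the form popularized by Kornblith et al. (2019), compares two representations of the same examples even when their widths differ. The sketch below assumes the linear-kernel variant (the text does not say which kernel the paper used) and computes it from n × n Gram matrices. The second half feeds it independent random features with widths borrowed from the text (888 for RegNet, 4096 for VGG-13): with only a few hundred examples, the score climbs with dimension even though the inputs share no structure, which is one concrete way a wide feature space can inflate CKA. It illustrates the suspected artifact, not the paper’s measurement.

```python
import torch
from torch import Tensor


def linear_cka(X: Tensor, Y: Tensor) -> float:
    """Linear CKA between two representations of the same n examples.

    X: (n, d1), Y: (n, d2); d1 and d2 may differ. Uses n x n Gram matrices,
    since ||Y^T X||_F^2 = <X X^T, Y Y^T>_F for centered X and Y.
    """
    X = X - X.mean(dim=0, keepdim=True)  # center every feature over the examples
    Y = Y - Y.mean(dim=0, keepdim=True)
    K, L = X @ X.T, Y @ Y.T              # (n, n) linear-kernel Gram matrices
    hsic_xy = (K * L).sum()
    return float(hsic_xy / (K.norm() * L.norm()))


torch.manual_seed(0)
n = 256
X = torch.randn(n, 64, dtype=torch.float64)
Q, _ = torch.linalg.qr(torch.randn(64, 64, dtype=torch.float64))  # random orthogonal matrix
print(linear_cka(X, 3.0 * X @ Q))  # invariant to rotation and isotropic scaling: ~1.0

# Independent random features: no shared structure, yet the score grows with width
scores = {}
for d in (64, 888, 4096):
    A = torch.randn(n, d, dtype=torch.float64)
    B = torch.randn(n, d, dtype=torch.float64)
    scores[d] = linear_cka(A, B)
    print(d, round(scores[d], 3))
```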
The broader lesson is a counterintuitive one: the best intermediate teacher is not necessarily the largest available network. It is the one whose architectural inductive biases best complement the VLM’s feature structure, enabling reliable feature-level alignment during Stage 1 training.
Results: The Numbers Across Five Datasets and Four Students
DAIT vs. the Full Competitive Field
| Method | CUB-200 | Aircraft | Sf Dogs | Sf Cars | NABirds |
|---|---|---|---|---|---|
| ResNet-18 (w/o KD) | 64.95 | 50.98 | 67.80 | 70.03 | 57.01 |
| KD (T: ResNet-152) | 65.43 | 51.94 | 67.37 | 72.07 | 57.05 |
| RKD (T: ResNet-152) | 66.49 | 54.58 | 68.80 | 72.96 | 58.87 |
| KD (T: CLIP) | 70.95 | 53.83 | 68.80 | 74.99 | 62.22 |
| RKD (T: CLIP) | 68.31 | 50.98 | 69.03 | 72.19 | 61.16 |
| RISE | 69.69 | 54.81 | 68.72 | 72.45 | 59.00 |
| VL2Lite | 71.38 | 55.87 | 72.40 | 77.09 | 63.26 |
| DAIT (Ours) | 79.77 | 67.44 | 78.10 | 88.96 | 74.38 |
Table 1: ResNet-18 results (top-1 accuracy %). DAIT leads all methods on all five datasets. The Stanford Cars gains are particularly dramatic: +11.87% over VL2Lite and +18.93% over the no-KD baseline.
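The margins quoted in the caption can be recomputed directly from the table rows. A few lines of Python, with values copied from Table 1, give every DAIT margin in percentage points; they also show that Stanford Dogs has the smallest margin over VL2Lite, +5.70 points.

```python
# ResNet-18 top-1 accuracy (%), copied from Table 1
DATASETS = ["CUB-200", "Aircraft", "Sf Dogs", "Sf Cars", "NABirds"]
NO_KD = [64.95, 50.98, 67.80, 70.03, 57.01]
VL2LITE = [71.38, 55.87, 72.40, 77.09, 63.26]
DAIT = [79.77, 67.44, 78.10, 88.96, 74.38]


def gains(ours, ref):
    """Per-dataset difference in percentage points, rounded to the table's precision."""
    return [round(o - r, 2) for o, r in zip(ours, ref)]


for name, g_prev, g_base in zip(DATASETS, gains(DAIT, VL2LITE), gains(DAIT, NO_KD)):
    print(f"{name:9s} vs VL2Lite {g_prev:+6.2f}   vs no-KD {g_base:+6.2f}")
```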
DAIT-F vs. VL2Lite Across All Student Architectures
| Student | CUB-200 (DAIT-F) | Gain vs VL2Lite | Aircraft (DAIT-F) | Gain vs VL2Lite | Sf Cars (DAIT-F) | Gain vs VL2Lite |
|---|---|---|---|---|---|---|
| ResNet-18 | 79.77 | +8.39% | 67.44 | +11.57% | 88.96 | +11.87% |
| MobileNet-V2 | 79.52 | +8.50% | 67.47 | +13.65% | 88.59 | +13.60% |
| ShuffleNet-V2 | 82.56 | +7.57% | 69.25 | +14.37% | 90.02 | +11.91% |
| EfficientNet-B0 | 82.76 | +8.89% | 70.99 | +18.81% | 90.33 | +15.33% |
Table 2: DAIT-F (feature distillation) vs. VL2Lite across all four lightweight students. EfficientNet-B0 with DAIT reaches 70.99% on FGVC-Aircraft — an 18.81% gain over VL2Lite’s 52.18%.
The consistency across all four student architectures is the headline result here. A framework that works well for ResNet-18 but not MobileNet-V2 would have limited practical value. In Table 2, DAIT-F improves on VL2Lite by +7.57 to +18.81 percentage points for every student on every dataset shown, and the paper reports gains with logit-level distillation (DAIT-L) as well. Feature-level distillation consistently outperforms logit-level, suggesting that the spatial feature maps carry more useful information than the class-probability vector alone.
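Summarizing the gain columns of Table 2 makes the cross-architecture range concrete: +7.57 to +18.81 points. The CUB-200 mean over the four students, about 8.34, matches the CUB-200 figure in the introduction. The Aircraft mean over these four DAIT-F rows is 14.60 points, higher than the +12.63% quoted for FGVC-Aircraft, so that headline presumably averages a different set of runs (for example, including DAIT-L); this table alone cannot settle which.

```python
# DAIT-F gains over VL2Lite (percentage points) for ResNet-18, MobileNet-V2,
# ShuffleNet-V2 and EfficientNet-B0, copied from Table 2
TABLE2_GAINS = {
    "CUB-200": [8.39, 8.50, 7.57, 8.89],
    "Aircraft": [11.57, 13.65, 14.37, 18.81],
    "Sf Cars": [11.87, 13.60, 11.91, 15.33],
}


def summarize(gains):
    """Mean, min and max gain for each dataset."""
    return {name: (sum(g) / len(g), min(g), max(g)) for name, g in gains.items()}


for name, (mean, lo, hi) in summarize(TABLE2_GAINS).items():
    print(f"{name:8s} mean {mean:5.2f}   range {lo:+.2f} to {hi:+.2f}")
```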
Low-Data Performance: The Gain Gets Bigger When Data Gets Scarcer
| Data Ratio | Method | CUB-200 | Aircraft | Sf Cars |
|---|---|---|---|---|
| 30% | VL2Lite | 50.32 | 33.72 | 43.36 |
| 30% | DAIT | 66.33 | 44.32 | 72.64 |
| 50% | VL2Lite | 60.87 | 45.18 | 60.41 |
| 50% | DAIT | 74.39 | 54.87 | 81.85 |
| 100% | VL2Lite | 71.38 | 55.87 | 77.09 |
| 100% | DAIT | 79.77 | 67.44 | 88.96 |
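Table 3: Low-data generalization (ResNet-18). On CUB-200 and Stanford Cars, DAIT’s margin over VL2Lite widens as training data shrinks; on FGVC-Aircraft it stays roughly constant (about 10 to 12 points). With 50% of the training data, DAIT already exceeds VL2Lite’s full-data accuracy on CUB-200 (74.39% vs. 71.38%) and Stanford Cars (81.85% vs. 77.09%).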
The low-data results are striking. With only 30% of the training data, DAIT’s ResNet-18 reaches 72.64% on Stanford Cars: 29.28 points above VL2Lite at the same ratio and well above VL2Lite’s 60.41% at 50% data, though still below VL2Lite’s full-data 77.09%. On CUB-200 and Stanford Cars the gap between DAIT and VL2Lite is largest at 30% training data (16.01 and 29.28 points, versus 8.39 and 11.87 with the full dataset); on FGVC-Aircraft it is roughly flat (10.60, 9.69, and 11.57 points at 30%, 50%, and 100%). This suggests that DAIT’s intermediate teacher provides a particularly effective scaffold when labeled examples are scarce: the fine-grained representations distilled from the VLM act as a strong prior that reduces the amount of task data needed to learn discriminative features.
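The per-ratio margins behind these low-data claims can be checked directly, with values copied from Table 3. For each dataset the script reports the DAIT minus VL2Lite gap at each data ratio, the ratio where that gap is widest, and whether DAIT at 50% data beats VL2Lite at 100%.

```python
# ResNet-18 top-1 accuracy (%) from Table 3: data ratio -> (VL2Lite, DAIT)
TABLE3 = {
    "CUB-200": {0.3: (50.32, 66.33), 0.5: (60.87, 74.39), 1.0: (71.38, 79.77)},
    "Aircraft": {0.3: (33.72, 44.32), 0.5: (45.18, 54.87), 1.0: (55.87, 67.44)},
    "Sf Cars": {0.3: (43.36, 72.64), 0.5: (60.41, 81.85), 1.0: (77.09, 88.96)},
}


def gaps(rows):
    """DAIT minus VL2Lite in percentage points, keyed by data ratio."""
    return {ratio: round(dait - vl2lite, 2) for ratio, (vl2lite, dait) in rows.items()}


for name, rows in TABLE3.items():
    g = gaps(rows)
    widest = max(g, key=g.get)                     # data ratio with the largest margin
    half_beats_full = rows[0.5][1] > rows[1.0][0]  # DAIT@50% vs VL2Lite@100%
    print(name, g, "widest at", widest, "| DAIT@50% > VL2Lite@100%:", half_beats_full)
```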
“The intermediate teacher not merely bridges the gap in model size, but rather functions as a knowledge filter and adapter — reorganizing the teacher’s knowledge in a task-oriented manner by suppressing irrelevant information and emphasizing fine-grained discriminative patterns.” — He, Li, and Wu, arXiv:2603.15166 (2026)
Complete End-to-End DAIT Implementation (PyTorch)
The implementation below is a PyTorch implementation of DAIT, structured across 10 sections that map to the paper. It covers the frozen VLM wrapper (OpenCLIP ConvNeXt-XXLarge with condensation projections, mocked by default so the code runs without the 1.2B-parameter checkpoint), the Semantic Image Alignment loss with temperature-scaled KL divergence, the Image Representation Alignment L1 loss, the dynamic lambda schedule (λ = ke + b), the Spatial Representation Alignment MSE loss for Stage 2, wrappers for the four student architectures (ResNet-18, MobileNet-V2, ShuffleNet-V2, EfficientNet-B0) and the RegNet-Y-1.6GF intermediate teacher, the full two-stage training loop with AdamW and step-wise learning-rate decay (`lr_decay_step`, `lr_decay_gamma`), dataset helpers for the five FGVC benchmarks, and a smoke test covering all components.
# ==============================================================================
# DAIT: Distillation from Vision-Language Models to Lightweight Classifiers
# with Adaptive Intermediate Teacher Transfer
# Paper: arXiv:2603.15166v1 [cs.CV] (2026)
# Authors: Zhengxu He, Jun Li, Zhijian Wu
# Affiliations: Nanjing Normal University · Westlake University
# ==============================================================================
# Sections:
# 1. Imports & Configuration
# 2. VLM Wrapper (frozen ConvNeXt-XXLarge / OpenCLIP backbone)
# 3. Intermediate Teacher (RegNet-Y-1.6GF with task adaptation)
# 4. Lightweight Student Architectures
# 5. Stage 1 Loss Functions (SIA, IRA, CLS → VLM2Inter)
# 6. Stage 2 Loss Functions (SRA, CLS → Inter2Lite)
# 7. Dynamic Lambda Schedule
# 8. Dataset Helpers (FGVC benchmarks)
# 9. Two-Stage Training Loop
# 10. Smoke Test
# ==============================================================================
from __future__ import annotations
import copy
import math
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
warnings.filterwarnings("ignore")
# ─── SECTION 1: Configuration ─────────────────────────────────────────────────
@dataclass
class DAITConfig:
    """
    Full configuration for the DAIT training framework.

    Attributes
    ----------
    n_classes      : int   — number of fine-grained categories
    feature_dim    : int   — shared embedding dimension D (paper: 512)
    temperature    : float — KL distillation temperature T (paper: 2.0)
    lambda_k       : float — linear rate of change for λ schedule (paper: see Fig.7)
    lambda_b       : float — bias for λ schedule (paper: 0.0 or 0.5)
    lambda_max     : float — maximum λ value (caps at 1.0)
    lr             : float — AdamW learning rate (paper: 1e-4)
    weight_decay   : float — AdamW weight decay (paper: 1e-4)
    lr_decay_step  : int   — epochs between LR decay steps (paper: 30)
    lr_decay_gamma : float — LR decay factor
    stage1_epochs  : int   — Stage 1 training epochs (paper: 100)
    stage2_epochs  : int   — Stage 2 training epochs (paper: 100)
    batch_size     : int   — training batch size (paper: 32)
    img_size       : int   — input image size (paper: 224)
    student_arch   : str   — student architecture name
    """
    n_classes: int = 200        # CUB-200-2011 default
    feature_dim: int = 512
    temperature: float = 2.0
    # NOTE: e counts raw epochs, so k = 0.5 saturates λ at 1.0 from epoch 2 on and
    # switches the distillation terms off for the rest of training. A ramp that
    # spans the whole stage needs k ≈ 1 / stage epochs (e.g. 0.01 for 100 epochs).
    lambda_k: float = 0.5       # linear rate of change (see Eq. 11)
    lambda_b: float = 0.0       # bias term (see Eq. 11)
    lambda_max: float = 1.0     # cap λ at 1.0
    lr: float = 1e-4
    weight_decay: float = 1e-4
    lr_decay_step: int = 30
    lr_decay_gamma: float = 0.1
    stage1_epochs: int = 100
    stage2_epochs: int = 100
    batch_size: int = 32
    img_size: int = 224
    student_arch: str = "resnet18"

    def lambda_at_epoch(self, epoch: int) -> float:
        """
        Dynamic lambda schedule: λ = ke + b (Eq. 11).
        Starts near 0 (distillation dominates), grows linearly (task dominates).

        Parameters
        ----------
        epoch : int — current training epoch (0-indexed)

        Returns
        -------
        lambda : float clamped to [0, lambda_max]
        """
        lam = self.lambda_k * epoch + self.lambda_b
        return float(min(max(lam, 0.0), self.lambda_max))
# ─── SECTION 2: VLM Wrapper (Frozen) ──────────────────────────────────────────
class CondensationProjection(nn.Module):
    """
    Two-layer MLP projection f_vlm(·) to map high-dimensional VLM features
    into a lower-dimensional shared embedding space (Eqs. 2–3).

    The VLM (ConvNeXt-XXLarge from OpenCLIP) outputs 1024-dim features.
    This projection reduces them to feature_dim (default 512) for efficient
    alignment with the intermediate teacher.

    Parameters
    ----------
    in_dim     : int — VLM output dimension (paper: 1024 for ConvNeXt-XXLarge)
    out_dim    : int — target embedding dimension D
    hidden_dim : int — hidden layer size (default: (in_dim + out_dim) // 2)
    """
    def __init__(self, in_dim: int = 1024, out_dim: int = 512, hidden_dim: int = 768):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)
class FrozenVLM(nn.Module):
    """
    Frozen VLM wrapper providing:
      - Image embeddings z_v via visual encoder (Eq. 2)
      - Text embeddings t_c via text encoder (Eq. 3)

    With mock=False, loads ConvNeXt-XXLarge from OpenCLIP
    (pip install open-clip-torch):
        model, _, preprocess = open_clip.create_model_and_transforms(
            'convnext_xxlarge', pretrained='laion2b_s34b_b82k_augreg_soup'
        )
        tokenizer = open_clip.get_tokenizer('convnext_xxlarge')
    Real VLM inputs must use OpenCLIP's own normalization (self.preprocess).

    With mock=True (default, for smoke testing), encode_image/encode_text
    return random unit-normalized embeddings of the correct shape.

    Parameters
    ----------
    vlm_dim     : int  — raw VLM feature dimension (paper: 1024)
    feature_dim : int  — projected embedding dimension D (paper: 512)
    n_classes   : int  — number of classes (for pre-computing text embeddings)
    mock        : bool — use random embeddings instead of the real VLM
    """
    def __init__(
        self,
        vlm_dim: int = 1024,
        feature_dim: int = 512,
        n_classes: int = 200,
        mock: bool = True,
    ):
        super().__init__()
        self.vlm_dim = vlm_dim
        self.feature_dim = feature_dim
        self.n_classes = n_classes
        self.mock = mock
        # Condensation projections f_vlm(·) (learnable during Stage 1, Eq. 2)
        self.proj_image = CondensationProjection(vlm_dim, feature_dim)
        self.proj_text = CondensationProjection(vlm_dim, feature_dim)
        # The VLM backbone itself is completely frozen and never updated
        if not mock:
            self._load_openclip()

    def _load_openclip(self):
        """Load the real OpenCLIP ConvNeXt-XXLarge and freeze all of its parameters."""
        try:
            import open_clip
        except ImportError as e:
            raise ImportError("Install open_clip: pip install open-clip-torch") from e
        self.vlm, _, self.preprocess = open_clip.create_model_and_transforms(
            'convnext_xxlarge', pretrained='laion2b_s34b_b82k_augreg_soup'
        )
        self.tokenizer = open_clip.get_tokenizer('convnext_xxlarge')
        self.vlm.eval()
        for p in self.vlm.parameters():
            p.requires_grad = False

    def train(self, mode: bool = True):
        # Only the projections switch modes; the frozen backbone always stays in eval()
        super().train(mode)
        if not self.mock:
            self.vlm.eval()
        return self

    def encode_image(self, x: Tensor) -> Tensor:
        """Frozen VLM image features, unit-normalized: (B, vlm_dim)."""
        if not self.mock:
            with torch.no_grad():
                features = self.vlm.encode_image(x)
            return F.normalize(features.float(), dim=-1)
        # Mock: random unit-normalized embedding
        raw = torch.randn(x.shape[0], self.vlm_dim, device=x.device)
        return F.normalize(raw, dim=-1)

    def encode_text(self, class_names: List[str], device: torch.device) -> Tensor:
        """
        Frozen VLM text features for all class prompts, unit-normalized: (N, vlm_dim).
        prompt(c) = "A photo of a {classname}"
        """
        if not self.mock:
            prompts = [f"A photo of a {cn}" for cn in class_names]
            tokens = self.tokenizer(prompts).to(device)
            with torch.no_grad():
                text_features = self.vlm.encode_text(tokens)
            return F.normalize(text_features.float(), dim=-1)
        # Mock: random unit-normalized embeddings
        raw = torch.randn(len(class_names), self.vlm_dim, device=device)
        return F.normalize(raw, dim=-1)

    def forward(self, x: Tensor, class_names: Optional[List[str]] = None) -> Dict[str, Tensor]:
        """
        Returns projected VLM image and text embeddings.

        Parameters
        ----------
        x           : (B, 3, H, W) — augmented input images
        class_names : List[str]    — class names for text prompts

        Returns
        -------
        dict with:
            z_v : (B, D) — projected VLM image embeddings
            t_c : (N, D) — projected text embeddings (if class_names provided)
        """
        # VLM image encoding (frozen backbone; only the projection is trainable)
        with torch.no_grad():
            raw_image = self.encode_image(x)                        # (B, vlm_dim)
        z_v = self.proj_image(raw_image)                            # (B, D)
        result = {"z_v": z_v}
        if class_names is not None:
            with torch.no_grad():
                raw_text = self.encode_text(class_names, x.device)  # (N, vlm_dim)
            t_c = self.proj_text(raw_text)                          # (N, D)
            result["t_c"] = t_c
        return result
# ─── SECTION 3: Intermediate Teacher ──────────────────────────────────────────
class IntermediateTeacher(nn.Module):
    """
    Trainable intermediate teacher model (RegNet-Y-1.6GF backbone).

    Serves dual roles (Section 3.1):
      1. During Stage 1: learns from VLM while being supervised on FGVC task.
         Its representations become compact, task-aligned filters of VLM knowledge.
      2. During Stage 2: frozen, provides discriminative feature supervision
         to the lightweight student via Spatial Representation Alignment.

    The ablation study (Fig. 5) shows RegNet-Y-1.6GF outperforms VGG-13 (129M),
    ResNet-50 (26.7M), ResNet-34 (21.3M), and EfficientNet-B0 (4.1M) despite
    having only 10.4M parameters, slightly fewer than the student ResNet-18
    (11.2M). Architectural compatibility matters more than model size.

    Parameters
    ----------
    n_classes   : int  — number of fine-grained categories
    feature_dim : int  — projected embedding dimension D
    pretrained  : bool — load ImageNet pre-trained weights
    """
    def __init__(self, n_classes: int = 200, feature_dim: int = 512, pretrained: bool = True):
        super().__init__()
        self.n_classes = n_classes
        self.feature_dim = feature_dim
        # self.backbone is the convolutional trunk only (no global pooling), so it
        # yields spatial maps for Stage 2 SRA; forward() pools them for Stage 1.
        try:
            from torchvision.models import regnet_y_1_6gf, RegNet_Y_1_6GF_Weights
            weights = RegNet_Y_1_6GF_Weights.IMAGENET1K_V2 if pretrained else None
            regnet = regnet_y_1_6gf(weights=weights)
            self.backbone_out_dim = regnet.fc.in_features  # 888 for RegNet-Y-1.6GF
            self.backbone = nn.Sequential(regnet.stem, regnet.trunk_output)
            print(f"  [IntermTeacher] RegNet-Y-1.6GF loaded | backbone_out={self.backbone_out_dim}")
        except Exception as e:
            # Fallback: lightweight mock trunk that also ends in a spatial map
            print(f"  [IntermTeacher] Using mock backbone ({e})")
            self.backbone_out_dim = 256
            self.backbone = nn.Sequential(
                nn.Conv2d(3, 64, 7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU(),
                nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
                nn.Conv2d(128, self.backbone_out_dim, 3, stride=2, padding=1),
                nn.BatchNorm2d(self.backbone_out_dim), nn.ReLU(),
                nn.AdaptiveAvgPool2d(7),  # fixed 7x7 map, as RegNet gives at 224 px
            )
        # Embedding projection to shared space (for SIA + IRA alignment)
        self.embed_proj = nn.Sequential(
            nn.Linear(self.backbone_out_dim, feature_dim),
            nn.GELU(),
            nn.LayerNorm(feature_dim),
        )
        # Classification head
        self.classifier = nn.Linear(feature_dim, n_classes)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """
        Parameters
        ----------
        x : (B, 3, H, W) — input images

        Returns
        -------
        dict with:
            z_t_embed : (B, D)          — projected embedding z̃_t (for SIA/IRA alignment)
            z_t_feat  : (B, C, H', W')  — spatial feature maps (for SRA in Stage 2)
            logits    : (B, N)          — classification logits
        """
        feats_spatial = self.backbone(x)                 # (B, C, H', W'), one backbone pass
        feats_global = feats_spatial.mean(dim=(-2, -1))  # global average pool -> (B, C)
        z_t_embed = self.embed_proj(feats_global)        # (B, D)
        logits = self.classifier(z_t_embed)              # (B, N)
        return {
            "z_t_embed": z_t_embed,
            "z_t_feat": feats_spatial,
            "logits": logits,
        }
# ─── SECTION 4: Lightweight Student Architectures ─────────────────────────────
class LightweightStudent(nn.Module):
    """
    Wrapper for lightweight student architectures evaluated in the paper:
      - ResNet-18       (11.2M params) — primary benchmark student
      - MobileNet-V2    (3.4M params)
      - ShuffleNet-V2   (2.3M params)
      - EfficientNet-B0 (5.3M params)

    Adds a convolutional adapter f_stu(·) on top of the final backbone
    feature map to project it to the intermediate teacher's spatial feature
    dimension for SRA alignment (Eq. 12).

    Parameters
    ----------
    arch        : str  — architecture name (resnet18/mobilenet_v2/shufflenet_v2/efficientnet_b0)
    n_classes   : int  — number of fine-grained categories
    feature_dim : int  — adapter output channels; for SRA this must equal the
                         teacher's feature-map channels (888 for RegNet-Y-1.6GF)
    pretrained  : bool — load ImageNet pre-trained weights
    """
    ARCH_CONFIGS = {
        "resnet18": ("resnet18", 512),
        "mobilenet_v2": ("mobilenet_v2", 1280),
        "shufflenet_v2": ("shufflenet_v2_x1_0", 1024),
        "efficientnet_b0": ("efficientnet_b0", 1280),
    }

    def __init__(
        self,
        arch: str = "resnet18",
        n_classes: int = 200,
        feature_dim: int = 512,
        pretrained: bool = True,
    ):
        super().__init__()
        self.arch = arch
        self.n_classes = n_classes
        self.feature_dim = feature_dim
        backbone_name, backbone_out = self.ARCH_CONFIGS.get(arch, ("resnet18", 512))
        try:
            import torchvision.models as tvm
            # torchvision >= 0.13 weights API (replaces the deprecated pretrained=True)
            model = getattr(tvm, backbone_name)(weights="DEFAULT" if pretrained else None)
            self.backbone_out = backbone_out
            # Keep only the convolutional trunk (no global pool, no classifier) so the
            # backbone returns a spatial map (B, C, H', W') for SRA alignment
            if "resnet" in backbone_name:
                self.backbone = nn.Sequential(*list(model.children())[:-2])  # drop avgpool, fc
            elif "shufflenet" in backbone_name:
                self.backbone = nn.Sequential(*list(model.children())[:-1])  # drop fc
            else:  # mobilenet_v2, efficientnet_b0
                self.backbone = model.features
        except Exception as e:
            print(f"  [Student] Mock backbone for {arch} ({e})")
            self.backbone_out = 256
            self.backbone = nn.Sequential(
                nn.Conv2d(3, 64, 7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU(),
                nn.Conv2d(64, self.backbone_out, 3, stride=2, padding=1),
                nn.BatchNorm2d(self.backbone_out), nn.ReLU(),
                nn.AdaptiveAvgPool2d(7),  # fixed 7x7 map to match the teacher's grid
            )
        # Convolutional adapter f_stu(·) for SRA alignment (Eq. 12):
        # projects backbone feature channels to feature_dim for teacher alignment
        self.feat_adapter = nn.Sequential(
            nn.Conv2d(self.backbone_out, feature_dim, 1, bias=False),
            nn.BatchNorm2d(feature_dim),
            nn.ReLU(),
        )
        # Classification head
        self.classifier = nn.Linear(feature_dim, n_classes)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """
        Parameters
        ----------
        x : (B, 3, H, W) — input images

        Returns
        -------
        dict with:
            z_s        : (B, D, H', W') — adapted spatial feature maps z_s (Eq. 12)
            z_s_global : (B, D)         — globally pooled embedding
            logits     : (B, N)         — classification logits
        """
        z_s_spatial = self.backbone(x)        # (B, backbone_out, H', W')
        z_s = self.feat_adapter(z_s_spatial)  # (B, D, H', W')
        z_s_global = z_s.mean(dim=(-2, -1))   # (B, D) global average pool
        logits = self.classifier(z_s_global)  # (B, N)
        return {"z_s": z_s, "z_s_global": z_s_global, "logits": logits}

# ─── SECTION 5: Stage 1 Loss Functions ────────────────────────────────────────

class SemanticImageAlignmentLoss(nn.Module):
    """
    L_SIA: Semantic Image Alignment Loss (Eq. 7).

    Aligns the intermediate teacher's visual-semantic relationships with
    the VLM's by minimizing the KL divergence between their image-text
    cosine-similarity distributions.

    The text encoder serves as a cross-modal anchor: the teacher must
    organize its visual space relative to the textual class descriptions
    in the same way the VLM organizes its own visual space.

        p_t = softmax(cos(z̃_t, t_c) / T),   p_v = softmax(cos(z_v, t_c) / T)
        L_SIA = T² · KL(p_v ‖ p_t)

    As implemented, the VLM distribution p_v is the target: F.kl_div(log p_t, p_v)
    computes KL(p_v ‖ p_t).

    Parameters
    ----------
    temperature : float — distillation temperature T (paper: 2.0)
    """

    def __init__(self, temperature: float = 2.0):
        super().__init__()
        self.T = temperature

    def forward(self, z_t: Tensor, z_v: Tensor, t_c: Tensor) -> Tensor:
        """
        Parameters
        ----------
        z_t : (B, D) — intermediate teacher image embedding z̃_t
        z_v : (B, D) — VLM image embedding z_v
        t_c : (N, D) — VLM text embeddings for all N classes

        Returns
        -------
        loss : scalar — T²-scaled KL divergence
        """
        T = self.T
        # L2-normalize all embeddings before cosine similarity
        z_t_n = F.normalize(z_t, dim=-1)  # (B, D)
        z_v_n = F.normalize(z_v, dim=-1)  # (B, D)
        t_c_n = F.normalize(t_c, dim=-1)  # (N, D)
        # Cosine similarity: each image vs all class text embeddings (Eqs. 5–6)
        cos_t = z_t_n @ t_c_n.T  # (B, N) — teacher image-text similarity
        cos_v = z_v_n @ t_c_n.T  # (B, N) — VLM image-text similarity
        # Temperature-scaled distributions (Eq. 7): the VLM's is the target,
        # the teacher's is passed to F.kl_div as log-probabilities
        p_v = F.softmax(cos_v / T, dim=-1)          # (B, N) VLM distribution (target)
        log_p_t = F.log_softmax(cos_t / T, dim=-1)  # (B, N) teacher log-distribution
        # F.kl_div(log q, p) = Σ p·(log p − log q), i.e. KL(p_v ‖ p_t) averaged over the batch
        loss = T ** 2 * F.kl_div(log_p_t, p_v, reduction="batchmean")
        return loss

class ImageRepresentationAlignmentLoss(nn.Module):
    """
    L_IRA: Image Representation Alignment Loss (Eq. 8).

    L1 calibration between teacher embeddings and VLM visual embeddings.
    Encourages the intermediate teacher to inherit the VLM's representational
    strength at the feature-vector level.

        L_IRA = (1/BD) Σ_i Σ_j |z̃_t(i,j) − z_v(i,j)|
    """

    def forward(self, z_t: Tensor, z_v: Tensor) -> Tensor:
        """
        Parameters
        ----------
        z_t : (B, D) — intermediate teacher embedding
        z_v : (B, D) — VLM image embedding

        Returns
        -------
        loss : scalar — mean absolute element-wise difference
        """
        return (z_t - z_v).abs().mean()  # Eq. 8: (1/BD) Σ |z̃_t - z_v|

class Stage1Loss(nn.Module):
    """
    Combined Stage 1 loss: VLM → Intermediate Teacher (Eq. 10).

        L_VLM2Inter = λ·L_CLS + (1−λ)/2 · (L_SIA + L_IRA)

    The dynamic λ schedule:
      - Early training: λ ≈ 0 → distillation dominates → teacher absorbs VLM knowledge
      - Late training:  λ → 1 → classification dominates → teacher adapts to task
    """

    def __init__(self, temperature: float = 2.0):
        super().__init__()
        self.sia = SemanticImageAlignmentLoss(temperature)
        self.ira = ImageRepresentationAlignmentLoss()

    def forward(
        self,
        z_t: Tensor,       # teacher embedding
        z_v: Tensor,       # VLM image embedding
        t_c: Tensor,       # VLM text embeddings
        logits_t: Tensor,  # teacher classification logits
        labels: Tensor,    # ground-truth class indices
        lam: float = 0.0,  # dynamic lambda at current epoch
    ) -> Dict[str, Tensor]:
        """
        Returns dict with all loss components and total.
        """
        l_sia = self.sia(z_t, z_v, t_c)
        l_ira = self.ira(z_t, z_v)
        l_cls = F.cross_entropy(logits_t, labels)
        # Combined loss (Eq. 10)
        l_total = lam * l_cls + (1.0 - lam) / 2.0 * (l_sia + l_ira)
        return {"total": l_total, "sia": l_sia, "ira": l_ira, "cls": l_cls}

# ─── SECTION 6: Stage 2 Loss Functions ────────────────────────────────────────

class SpatialRepresentationAlignmentLoss(nn.Module):
    """
    L_SRA: Spatial Representation Alignment Loss (Eq. 13).

    MSE loss between teacher and student final convolutional feature maps.
    The student learns fine-grained spatial cues by mimicking the intermediate
    teacher's (already task-adapted) feature maps.

        L_SRA = (1/HW) Σ_h Σ_w ‖z_s(h,w) − z_t(h,w)‖²₂

    Implemented with F.mse_loss, which also averages over the batch and the D
    channels, so it equals the batch mean of the expression above divided by D.
    The teacher map is detached: gradients flow only into the student.
    """

    def forward(self, z_s: Tensor, z_t: Tensor) -> Tensor:
        """
        Parameters
        ----------
        z_s : (B, D, H', W') — student adapted spatial feature maps
        z_t : (B, D, H', W') — intermediate teacher spatial feature maps

        Returns
        -------
        loss : scalar — mean squared spatial feature difference
        """
        # Align spatial dimensions if they differ (resize student to teacher size)
        if z_s.shape[-2:] != z_t.shape[-2:]:
            z_s = F.interpolate(z_s, size=z_t.shape[-2:], mode="bilinear", align_corners=False)
        # Fallback for mismatched channel counts: keep only the shared leading channels.
        # Normally the student's conv adapter already outputs matching channels.
        if z_s.shape[1] != z_t.shape[1]:
            min_ch = min(z_s.shape[1], z_t.shape[1])
            z_s = z_s[:, :min_ch]
            z_t = z_t[:, :min_ch]
        return F.mse_loss(z_s, z_t.detach())

class Stage2Loss(nn.Module):
    """
    Combined Stage 2 loss: Intermediate Teacher → Lightweight Student (Eq. 14).

        L_Inter2Lite = λ·L_CLS + (1−λ)·L_SRA
    """

    def __init__(self):
        super().__init__()
        self.sra = SpatialRepresentationAlignmentLoss()

    def forward(
        self,
        z_s: Tensor,       # student spatial features
        z_t: Tensor,       # teacher spatial features
        logits_s: Tensor,  # student classification logits
        labels: Tensor,    # ground-truth labels
        lam: float = 0.0,  # dynamic lambda
    ) -> Dict[str, Tensor]:
        l_sra = self.sra(z_s, z_t)
        l_cls = F.cross_entropy(logits_s, labels)
        l_total = lam * l_cls + (1.0 - lam) * l_sra
        return {"total": l_total, "sra": l_sra, "cls": l_cls}

# ─── SECTION 7: Dynamic Lambda Schedule ───────────────────────────────────────

class DynamicLambdaScheduler:
    """
    Implements the dynamic λ schedule described in Section 3.1 and Eq. (11).

        λ = k·e + b    where e = current epoch (0-indexed), clipped to [0, max_lam]

    This schedule starts near 0 (distillation loss dominates → teacher absorbs
    VLM knowledge) and increases linearly (classification loss dominates →
    task-specific adaptation). Best configuration from Fig. 7: k=0.5, b=0.
    Because k is applied per raw epoch here, k=0.5 reaches λ=1.0 at epoch 2.

    Parameters
    ----------
    k       : float — linear rate of change per epoch
    b       : float — initial bias (default 0: starts fully distillation-focused)
    max_lam : float — maximum lambda value (default 1.0)
    """

    def __init__(self, k: float = 0.5, b: float = 0.0, max_lam: float = 1.0):
        self.k = k
        self.b = b
        self.max_lam = max_lam

    def get(self, epoch: int) -> float:
        """Get lambda value for the given epoch (0-indexed)."""
        lam = self.k * epoch + self.b
        return float(min(max(lam, 0.0), self.max_lam))

    def describe(self, total_epochs: int) -> str:
        """Describe the schedule trajectory."""
        start = self.get(0)
        end = self.get(total_epochs - 1)
        return (f"λ schedule: k={self.k}, b={self.b} | "
                f"epoch 0: λ={start:.3f} → epoch {total_epochs-1}: λ={end:.3f}")

# ─── SECTION 8: Dataset Helpers ───────────────────────────────────────────────

class FGVCDataset(Dataset):
    """
    Mock FGVC dataset for smoke testing.

    Real dataset loading — download from official sources:
        CUB-200-2011:  http://www.vision.caltech.edu/datasets/cub_200_2011/
        FGVC-Aircraft: https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/
        Stanford Cars: http://ai.stanford.edu/~jkrause/cars/car_dataset.html
        Stanford Dogs: http://vision.stanford.edu/aditya86/ImageNetDogs/
        NABirds:       https://dl.allaboutbirds.org/nabirds

    Production loading example (torchvision ImageFolder; assumes the images have
    been arranged into one sub-folder per class for each split):

        from torchvision.datasets import ImageFolder
        from torchvision import transforms
        transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        dataset = ImageFolder(root='/path/to/CUB_200_2011/train', transform=transform)

    Parameters
    ----------
    n_samples : int  — number of mock samples
    n_classes : int  — number of classes (optional; defaults to the dataset's class count)
    img_size  : int  — image spatial size
    augment   : bool — whether to apply augmentation (Stage 1 uses augmented images)
    """

    DATASET_INFO = {
        "cub200":   {"n_classes": 200, "name": "CUB-200-2011 (11,788 images, 200 bird species)"},
        "aircraft": {"n_classes": 100, "name": "FGVC-Aircraft (10,000 images, 100 variants)"},
        "sf_cars":  {"n_classes": 196, "name": "Stanford Cars (16,185 images, 196 models)"},
        "sf_dogs":  {"n_classes": 120, "name": "Stanford Dogs (20,580 images, 120 breeds)"},
        "nabirds":  {"n_classes": 555, "name": "NABirds (48,562 images, 555 species)"},
    }

    def __init__(
        self,
        dataset: str = "cub200",
        n_samples: int = 128,
        n_classes: Optional[int] = None,
        img_size: int = 224,
        augment: bool = True,
    ):
        self.dataset = dataset
        self.n_samples = n_samples
        self.img_size = img_size
        info = self.DATASET_INFO.get(dataset, {"n_classes": 200})
        # An explicit n_classes (e.g. a reduced count for smoke tests) overrides the dataset default
        self.n_classes = n_classes if n_classes is not None else info["n_classes"]
        # Data augmentation A(·) used in Stage 1 (Eq. 1)
        if augment:
            self.transform = T.Compose([
                T.RandomResizedCrop(img_size),
                T.RandomHorizontalFlip(),
                T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        else:
            self.transform = T.Compose([
                T.Resize(int(img_size * 1.14)),
                T.CenterCrop(img_size),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # Mock sample: a random tensor already shaped like a normalized image.
        # self.transform is not applied because it expects a PIL image.
        img = torch.randn(3, self.img_size, self.img_size)
        label = torch.randint(0, self.n_classes, ())
        return img, label

    def get_class_names(self) -> List[str]:
        """Return class names for VLM text prompts."""
        return [f"class_{i}" for i in range(self.n_classes)]

# ─── SECTION 9: Two-Stage Training Loop ───────────────────────────────────────

class DAITTrainer:
    """
    Full DAIT two-stage training pipeline (Section 4 and Appendix).

    Stage 1 (VLM → Intermediate Teacher):
      - Intermediate teacher is trained with L_VLM2Inter (Eq. 10)
      - Uses augmented images x̃ = A(x) for both VLM and teacher
      - Dynamic λ schedule starts at 0 (distillation-heavy) → grows linearly
      - AdamW optimizer, lr=1e-4, decays every 30 epochs

    Stage 2 (Intermediate Teacher → Lightweight Student):
      - Intermediate teacher is FROZEN
      - Student is trained with L_Inter2Lite (Eq. 14) on original (non-augmented) images
      - Same dynamic λ schedule, same optimizer configuration
    """

    def __init__(self, config: DAITConfig, device: torch.device):
        self.config = config
        self.device = device
        # Build all components
        self.vlm = FrozenVLM(feature_dim=config.feature_dim, n_classes=config.n_classes).to(device)
        self.intermediate_teacher = IntermediateTeacher(
            n_classes=config.n_classes, feature_dim=config.feature_dim
        ).to(device)
        self.student = LightweightStudent(
            arch=config.student_arch, n_classes=config.n_classes, feature_dim=config.feature_dim
        ).to(device)
        # Loss functions
        self.stage1_loss = Stage1Loss(temperature=config.temperature)
        self.stage2_loss = Stage2Loss()
        # Lambda scheduler
        self.lambda_sched = DynamicLambdaScheduler(k=config.lambda_k, b=config.lambda_b)

    def _build_optimizer(self, model: nn.Module) -> Tuple:
        """AdamW optimizer with step-decay scheduler (Section 4)."""
        optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=self.config.lr,
            weight_decay=self.config.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.config.lr_decay_step,
            gamma=self.config.lr_decay_gamma,
        )
        return optimizer, scheduler

    def train_stage1(
        self,
        train_loader: DataLoader,
        class_names: List[str],
        log_interval: int = 5,
    ):
        """
        Stage 1: Train intermediate teacher on VLM + task supervision.

        The VLM condensation projections (proj_image, proj_text) are also
        updated during this stage — they learn to map VLM features to the
        shared embedding space that the intermediate teacher aligns with.
        """
        print("\n[Stage 1] VLM → Intermediate Teacher")
        print(f" Epochs: {self.config.stage1_epochs} | {self.lambda_sched.describe(self.config.stage1_epochs)}")
        # Trainable: intermediate teacher + VLM projections
        train_params = (
            list(self.intermediate_teacher.parameters()) +
            list(self.vlm.proj_image.parameters()) +
            list(self.vlm.proj_text.parameters())
        )
        optimizer = torch.optim.AdamW(
            train_params, lr=self.config.lr, weight_decay=self.config.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self.config.lr_decay_step, gamma=self.config.lr_decay_gamma
        )
        self.vlm.train()
        self.intermediate_teacher.train()
        for epoch in range(self.config.stage1_epochs):
            lam = self.lambda_sched.get(epoch)
            epoch_losses = {"total": 0.0, "sia": 0.0, "ira": 0.0, "cls": 0.0}
            correct, total = 0, 0
            for images, labels in train_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                # Forward: VLM features + intermediate teacher features
                vlm_out = self.vlm(images, class_names)
                z_v = vlm_out["z_v"]  # (B, D)
                t_c = vlm_out["t_c"]  # (N, D)
                teacher_out = self.intermediate_teacher(images)
                z_t = teacher_out["z_t_embed"]    # (B, D)
                logits_t = teacher_out["logits"]  # (B, N)
                # Stage 1 loss (Eq. 10)
                loss_dict = self.stage1_loss(z_t, z_v, t_c, logits_t, labels, lam)
                optimizer.zero_grad()
                loss_dict["total"].backward()
                torch.nn.utils.clip_grad_norm_(train_params, 1.0)
                optimizer.step()
                for k in epoch_losses:
                    epoch_losses[k] += loss_dict[k].item()
                correct += (logits_t.argmax(1) == labels).sum().item()
                total += labels.size(0)
            scheduler.step()
            n = len(train_loader)
            if (epoch + 1) % log_interval == 0 or epoch == 0:
                print(
                    f" S1 Epoch {epoch+1:3d}/{self.config.stage1_epochs} | "
                    f"λ={lam:.3f} | total={epoch_losses['total']/n:.4f} | "
                    f"sia={epoch_losses['sia']/n:.4f} | ira={epoch_losses['ira']/n:.4f} | "
                    f"acc={correct/total*100:.1f}%"
                )
        # Freeze intermediate teacher for Stage 2
        for p in self.intermediate_teacher.parameters():
            p.requires_grad = False
        self.intermediate_teacher.eval()
        print(" Intermediate teacher frozen. Beginning Stage 2.")

    def train_stage2(
        self,
        train_loader: DataLoader,
        log_interval: int = 5,
    ):
        """
        Stage 2: Train lightweight student with frozen intermediate teacher.

        Student learns from already task-adapted, noise-filtered teacher features.
        Uses original (non-augmented) images unlike Stage 1.
        """
        print(f"\n[Stage 2] Intermediate Teacher → Lightweight Student ({self.config.student_arch})")
        print(f" Epochs: {self.config.stage2_epochs} | {self.lambda_sched.describe(self.config.stage2_epochs)}")
        optimizer, scheduler = self._build_optimizer(self.student)
        self.student.train()
        for epoch in range(self.config.stage2_epochs):
            lam = self.lambda_sched.get(epoch)
            epoch_losses = {"total": 0.0, "sra": 0.0, "cls": 0.0}
            correct, total = 0, 0
            for images, labels in train_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                # Teacher forward (no grad)
                with torch.no_grad():
                    teacher_out = self.intermediate_teacher(images)
                    z_t_feat = teacher_out["z_t_feat"]  # (B, C, H', W')
                # Student forward
                student_out = self.student(images)
                z_s = student_out["z_s"]          # (B, D, H'', W'')
                logits_s = student_out["logits"]  # (B, N)
                # Align teacher spatial features to the student's D channels for SRA.
                # Simplification: keep the first D teacher channels (resized to the
                # student's spatial size) or zero-pad the channel axis up to D.
                D = self.config.feature_dim
                C = z_t_feat.shape[1]
                if C > D:
                    z_t_align = F.adaptive_avg_pool2d(z_t_feat[:, :D], z_s.shape[-2:])
                elif C < D:
                    z_t_align = F.pad(z_t_feat, (0, 0, 0, 0, 0, D - C))
                else:
                    z_t_align = z_t_feat
                # Stage 2 loss (Eq. 14)
                loss_dict = self.stage2_loss(z_s, z_t_align, logits_s, labels, lam)
                optimizer.zero_grad()
                loss_dict["total"].backward()
                torch.nn.utils.clip_grad_norm_(self.student.parameters(), 1.0)
                optimizer.step()
                for k in epoch_losses:
                    epoch_losses[k] += loss_dict[k].item()
                correct += (logits_s.argmax(1) == labels).sum().item()
                total += labels.size(0)
            scheduler.step()
            n = len(train_loader)
            if (epoch + 1) % log_interval == 0 or epoch == 0:
                print(
                    f" S2 Epoch {epoch+1:3d}/{self.config.stage2_epochs} | "
                    f"λ={lam:.3f} | total={epoch_losses['total']/n:.4f} | "
                    f"sra={epoch_losses['sra']/n:.4f} | "
                    f"acc={correct/total*100:.1f}%"
                )

    @torch.no_grad()
    def evaluate(self, val_loader: DataLoader) -> Dict[str, float]:
        """Evaluate both intermediate teacher and student on validation data."""
        self.intermediate_teacher.eval()
        self.student.eval()
        teacher_correct = student_correct = total = 0
        for images, labels in val_loader:
            images, labels = images.to(self.device), labels.to(self.device)
            t_out = self.intermediate_teacher(images)
            s_out = self.student(images)
            teacher_correct += (t_out["logits"].argmax(1) == labels).sum().item()
            student_correct += (s_out["logits"].argmax(1) == labels).sum().item()
            total += labels.size(0)
        return {
            "teacher_acc": teacher_correct / max(1, total) * 100,
            "student_acc": student_correct / max(1, total) * 100,
        }

    def train(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        class_names: Optional[List[str]] = None,
        stage2_loader: Optional[DataLoader] = None,
    ):
        """
        Run the full two-stage DAIT training pipeline.

        Stage 1 trains on train_loader (augmented images). Stage 2 should see
        original, non-augmented images, so pass a separate stage2_loader; if it
        is omitted, train_loader is reused for both stages.
        """
        if class_names is None:
            class_names = [f"class_{i}" for i in range(self.config.n_classes)]
        print(f"\n{'='*60}")
        print(f" DAIT Training | Dataset: {self.config.n_classes} classes")
        print(f" Student: {self.config.student_arch}")
        print(" VLM: ConvNeXt-XXLarge (frozen, 1200M params)")
        print(" Intermediate Teacher: RegNet-Y-1.6GF (10.4M params)")
        print(f"{'='*60}")
        self.train_stage1(train_loader, class_names)
        self.train_stage2(stage2_loader if stage2_loader is not None else train_loader)
        if val_loader is not None:
            results = self.evaluate(val_loader)
            print(f"\n Final Eval | Teacher: {results['teacher_acc']:.1f}% | Student: {results['student_acc']:.1f}%")
            return results
        return {}

# ─── SECTION 10: Smoke Test ───────────────────────────────────────────────────

if __name__ == "__main__":
    print("=" * 60)
    print("DAIT — Full Framework Smoke Test")
    print("=" * 60)
    torch.manual_seed(42)
    device = torch.device("cpu")
    N_CLASSES = 10  # reduced for fast testing (paper uses 100–555)
    D = 128         # reduced feature_dim (paper: 512)
    B = 4           # small batch

    # ── 1. Dynamic Lambda Schedule ───────────────────────────────────────────
    print("\n[1/7] Dynamic Lambda Scheduler...")
    sched = DynamicLambdaScheduler(k=0.5, b=0.0)
    print(f" {sched.describe(10)}")
    assert sched.get(0) == 0.0, "λ should start at 0"
    assert sched.get(2) == 1.0, "λ should be capped at 1.0 by epoch 2"
    lams = [sched.get(e) for e in range(5)]
    print(f" λ at epochs [0,1,2,3,4]: {lams} ✓")

    # ── 2. SIA Loss ──────────────────────────────────────────────────────────
    print("\n[2/7] Semantic Image Alignment Loss (L_SIA)...")
    sia = SemanticImageAlignmentLoss(temperature=2.0)
    z_t = F.normalize(torch.randn(B, D), dim=-1)
    z_v = F.normalize(torch.randn(B, D), dim=-1)
    t_c = F.normalize(torch.randn(N_CLASSES, D), dim=-1)
    loss_sia = sia(z_t, z_v, t_c)
    assert torch.isfinite(loss_sia) and loss_sia.item() >= 0.0
    print(f" L_SIA = {loss_sia.item():.4f} (KL divergence, ≥ 0) ✓")

    # ── 3. IRA Loss ──────────────────────────────────────────────────────────
    print("\n[3/7] Image Representation Alignment Loss (L_IRA)...")
    ira = ImageRepresentationAlignmentLoss()
    loss_ira = ira(z_t, z_v)
    assert torch.isfinite(loss_ira) and loss_ira.item() >= 0.0
    print(f" L_IRA = {loss_ira.item():.4f} (L1 alignment, ≥ 0) ✓")

    # ── 4. SRA Loss ──────────────────────────────────────────────────────────
    print("\n[4/7] Spatial Representation Alignment Loss (L_SRA)...")
    sra = SpatialRepresentationAlignmentLoss()
    z_s_feat = torch.randn(B, D, 7, 7)
    z_t_feat = torch.randn(B, D, 7, 7)
    loss_sra = sra(z_s_feat, z_t_feat)
    assert torch.isfinite(loss_sra) and loss_sra.item() >= 0.0
    print(f" L_SRA = {loss_sra.item():.4f} (MSE on spatial features, ≥ 0) ✓")

    # ── 5. Frozen VLM + Stage 1 combined loss ────────────────────────────────
    print("\n[5/7] Stage 1 combined loss with mock VLM...")
    vlm = FrozenVLM(vlm_dim=256, feature_dim=D, n_classes=N_CLASSES).to(device)
    dummy_img = torch.randn(B, 3, 32, 32)
    class_names = [f"species_{i}" for i in range(N_CLASSES)]
    vlm_out = vlm(dummy_img, class_names)
    assert vlm_out["z_v"].shape == (B, D)
    assert vlm_out["t_c"].shape == (N_CLASSES, D)
    stage1 = Stage1Loss(temperature=2.0)
    logits_mock = torch.randn(B, N_CLASSES)
    labels_mock = torch.randint(0, N_CLASSES, (B,))
    # z_v stands in for both the teacher and the VLM embedding, so sia and ira should be ≈ 0
    s1_out = stage1(vlm_out["z_v"], vlm_out["z_v"], vlm_out["t_c"], logits_mock, labels_mock, lam=0.3)
    assert torch.isfinite(s1_out["total"])
    print(f" L_VLM2Inter = {s1_out['total'].item():.4f} ✓ (sia={s1_out['sia'].item():.3f}, ira={s1_out['ira'].item():.3f})")

    # ── 6. Intermediate Teacher + Student architectures ───────────────────────
    print("\n[6/7] Teacher and Student architecture checks...")
    teacher = IntermediateTeacher(n_classes=N_CLASSES, feature_dim=D).to(device)
    t_out = teacher(dummy_img)
    assert t_out["z_t_embed"].shape == (B, D)
    assert t_out["logits"].shape == (B, N_CLASSES)
    n_teacher = sum(p.numel() for p in teacher.parameters())
    print(f" Intermediate Teacher params: {n_teacher:,} | logits: {tuple(t_out['logits'].shape)} ✓")
    for arch in ["resnet18"]:
        stu = LightweightStudent(arch=arch, n_classes=N_CLASSES, feature_dim=D, pretrained=False).to(device)
        s_out = stu(dummy_img)
        assert s_out["logits"].shape == (B, N_CLASSES)
        n_stu = sum(p.numel() for p in stu.parameters())
        print(f" Student [{arch}] params: {n_stu:,} | logits: {tuple(s_out['logits'].shape)} ✓")

    # ── 7. Full Training Run ─────────────────────────────────────────────────
    print("\n[7/7] Full two-stage training run (3 + 3 epochs, mock data)...")
    cfg = DAITConfig(
        n_classes=N_CLASSES,
        feature_dim=D,
        stage1_epochs=3,
        stage2_epochs=3,
        batch_size=4,
        lr=1e-3,
        lambda_k=0.5,
        student_arch="resnet18",
    )
    trainer = DAITTrainer(cfg, device)
    dataset = FGVCDataset(dataset="aircraft", n_samples=16, n_classes=N_CLASSES)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    trainer.train(loader, val_loader=loader, class_names=dataset.get_class_names())
    print("\n" + "=" * 60)
    print("✓ All DAIT smoke-test checks passed.")
    print("=" * 60)
    print("""
Next steps to reproduce paper results:
  1. Install OpenCLIP and load the real VLM:
       pip install open-clip-torch
       import open_clip
       model, _, preprocess = open_clip.create_model_and_transforms(
           'convnext_xxlarge', pretrained='laion2b_s34b_b82k_augreg_soup'
       )
     Then update FrozenVLM.encode_image() and encode_text() to use the real model.
  2. Download fine-grained datasets:
       CUB-200:       http://www.vision.caltech.edu/datasets/cub_200_2011/
       Aircraft:      https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/
       Stanford Cars: http://ai.stanford.edu/~jkrause/cars/car_dataset.html
       Stanford Dogs: http://vision.stanford.edu/aditya86/ImageNetDogs/
       NABirds:       https://dl.allaboutbirds.org/nabirds
  3. Use full paper hyperparameters:
       n_classes = {200, 100, 196, 120, 555} per dataset
       feature_dim = 512
       temperature = 2.0
       lambda_k = 0.5, lambda_b = 0.0 (best from Fig. 7a)
       batch_size = 32
       lr = 1e-4, weight_decay = 1e-4
       stage1_epochs = stage2_epochs = 100
       img_size = 224
  4. Results reported in the paper (vs. VL2Lite baseline):
       ResNet-18 on Aircraft:          67.44% (+11.57% over VL2Lite)
       ResNet-18 on CUB-200:           79.77% (+8.39% over VL2Lite)
       EfficientNet-B0 on Aircraft:    70.99% (+18.81% over VL2Lite)
       ShuffleNet-V2 on Stanford Cars: 90.02% (+11.91% over VL2Lite)
  5. For DAIT-L (logit distillation variant):
       Replace L_SRA with KL divergence on teacher/student output softmax.
""")
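Step 5 above describes DAIT-L in a single line. The framework-free sketch below shows one plausible reading, assuming standard temperature-scaled logit distillation: the frozen intermediate teacher's softened probabilities are the target, the KL term is scaled by T² as in SemanticImageAlignmentLoss, and it is mixed with cross-entropy using the λ weighting of Eq. 14. The function name `dait_l_loss` and that weighting are illustrative assumptions, not the paper's published code.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float) -> List[float]:
    """Temperature-scaled softmax, shifted by the max logit for numerical stability."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p ‖ q) = Σ p·(log p − log q); terms with p = 0 contribute nothing."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)


def cross_entropy(logits: List[float], label: int) -> float:
    """Negative log-probability of the true class under softmax(logits) at T = 1."""
    return -math.log(softmax(logits, 1.0)[label])


def dait_l_loss(student_logits: List[float], teacher_logits: List[float],
                label: int, lam: float, temperature: float = 2.0) -> float:
    """Hypothetical per-sample DAIT-L loss: λ·L_CLS + (1−λ)·T²·KL(p_teacher ‖ p_student)."""
    p_teacher = softmax(teacher_logits, temperature)  # frozen teacher is the target
    p_student = softmax(student_logits, temperature)
    l_kd = temperature ** 2 * kl_divergence(p_teacher, p_student)
    l_cls = cross_entropy(student_logits, label)
    return lam * l_cls + (1.0 - lam) * l_kd


if __name__ == "__main__":
    teacher = [2.0, 0.5, -1.0]
    student = [1.0, 1.0, 0.0]
    for lam in (0.0, 0.5, 1.0):
        print(f"lambda={lam:.1f}  loss={dait_l_loss(student, teacher, label=0, lam=lam):.4f}")
```

In the PyTorch code above, the same KD term would be `T ** 2 * F.kl_div(F.log_softmax(logits_s / T, dim=-1), F.softmax(logits_t / T, dim=-1), reduction="batchmean")`. Because `F.kl_div(log q, p)` computes KL(p ‖ q), `kl_divergence(p_teacher, p_student)` here matches the direction of that call, and of the SIA loss, where the VLM plays the target role.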
Read the Full Paper
The complete study — including CKA feature similarity analysis, t-SNE feature distribution visualizations, full ablation results on λ schedule configurations, and detailed per-dataset comparison tables — is available on arXiv.
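CKA (centered kernel alignment) scores how similar two sets of features are when computed on the same images, which is the kind of check used to see whether a student's representation has moved toward its teacher's. A minimal framework-free sketch of the standard linear CKA, assuming each feature set is an n×d matrix with one row per image; the paper's exact CKA setup, such as which layers it compares, is not reproduced here.

```python
from typing import List

Matrix = List[List[float]]


def center_columns(x: Matrix) -> Matrix:
    """Subtract each column's mean so every feature is zero-mean over the n samples."""
    n = len(x)
    means = [sum(row[j] for row in x) / n for j in range(len(x[0]))]
    return [[v - m for v, m in zip(row, means)] for row in x]


def cross_frob_sq(a: Matrix, b: Matrix) -> float:
    """Squared Frobenius norm ‖aᵀb‖²_F for an n×p matrix a and an n×q matrix b."""
    total = 0.0
    for i in range(len(a[0])):
        for j in range(len(b[0])):
            entry = sum(ra[i] * rb[j] for ra, rb in zip(a, b))  # (aᵀb)[i][j]
            total += entry * entry
    return total


def linear_cka(x: Matrix, y: Matrix) -> float:
    """Linear CKA = ‖Yᵀ X‖²_F / (‖Xᵀ X‖_F · ‖Yᵀ Y‖_F) on column-centered features.

    Assumes neither matrix is constant across samples (the denominator would be 0).
    """
    xc, yc = center_columns(x), center_columns(y)
    numerator = cross_frob_sq(yc, xc)
    denominator = (cross_frob_sq(xc, xc) * cross_frob_sq(yc, yc)) ** 0.5
    return numerator / denominator


if __name__ == "__main__":
    teacher_feats = [[1.0, 0.2, 3.0], [0.5, 1.5, 2.0], [2.0, 0.1, 0.5], [0.3, 2.2, 1.0]]
    student_feats = [[0.9, 2.8], [0.6, 1.9], [2.1, 0.6], [0.2, 1.2]]
    print(f"CKA(teacher, teacher) = {linear_cka(teacher_feats, teacher_feats):.4f}")
    print(f"CKA(teacher, student) = {linear_cka(teacher_feats, student_feats):.4f}")
```

Linear CKA is invariant to isotropic scaling and orthogonal rotation of either feature set and does not require the two sets to share a dimensionality, which is why it suits comparing a CNN student with a differently shaped teacher.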
He, Z., Li, J., & Wu, Z. (2026). DAIT: Distillation from Vision-Language Models to Lightweight Classifiers with Adaptive Intermediate Teacher Transfer. arXiv:2603.15166v1 [cs.CV]. Nanjing Normal University & Westlake University.
This article is an independent editorial analysis of peer-reviewed research. The Python implementation is an educational adaptation. The original authors used ConvNeXt-XXLarge from OpenCLIP as the VLM and RegNet-Y-1.6GF as the intermediate teacher, trained on a single NVIDIA RTX 3090 GPU for 100 epochs per stage. Refer to the paper for exact hyperparameter ablation details.
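With 100 epochs per stage, it helps to see how the linear schedule λ = k·e + b behaves over a full run. The framework-free sketch below reproduces the clipping in DynamicLambdaScheduler.get, where e is the raw epoch index, and compares it with a hypothetical variant in which e is normalized training progress epoch/(E−1). The normalized variant is an assumption for illustration only; check the paper's Eq. (11) for the definition it actually uses.

```python
from typing import List


def lambda_raw(epoch: int, k: float = 0.5, b: float = 0.0, max_lam: float = 1.0) -> float:
    """λ = k·e + b with e the raw 0-indexed epoch, clipped to [0, max_lam]."""
    return min(max(k * epoch + b, 0.0), max_lam)


def lambda_normalized(epoch: int, total_epochs: int, k: float = 0.5, b: float = 0.0,
                      max_lam: float = 1.0) -> float:
    """Hypothetical variant: e is replaced by training progress epoch / (total_epochs - 1)."""
    progress = epoch / max(1, total_epochs - 1)
    return min(max(k * progress + b, 0.0), max_lam)


def first_saturated_epoch(schedule: List[float], max_lam: float = 1.0) -> int:
    """Index of the first epoch at which λ reaches max_lam, or -1 if it never does."""
    for e, lam in enumerate(schedule):
        if lam >= max_lam:
            return e
    return -1


if __name__ == "__main__":
    E = 100
    raw = [lambda_raw(e) for e in range(E)]
    norm = [lambda_normalized(e, E) for e in range(E)]
    print("raw:        lambda[0..3] =", raw[:4], "| saturates at epoch", first_saturated_epoch(raw))
    print("normalized: lambda[0], lambda[49], lambda[99] =", norm[0], round(norm[49], 3), norm[99])
```

Under the raw-epoch reading with k = 0.5 and b = 0, λ reaches 1.0 at epoch 2, so the distillation terms carry weight only in the first two of 100 epochs. Under the normalized reading, λ rises to 0.5 at the final epoch and never saturates.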
Explore More on AI Trend Blend
If this article sparked your interest, here is more of what we cover — from model compression and VLM research to fine-grained recognition, continual learning, and efficient deployment.
