Teaching a Satellite to See the World Without Labels: How Mask-CDKD Squeezes SAM Into a 30M-Parameter Onboard Model
Researchers at Wuhan University built a distillation framework that transfers SAM's powerful visual priors to a compact student network using only unlabeled satellite imagery, with no source data and no labels, and deployed the result live on an edge AI board that adds roughly 13 watts to its idle power draw during inference.
Somewhere above the Earth’s surface, a very-high-resolution satellite is capturing a 1024×1024 pixel tile of a city every few seconds. The ideal outcome is that the satellite itself interprets that imagery — classifying roads, buildings, water bodies, vegetation, and bare land — before transmitting a compact semantic map rather than gigabytes of raw pixels. The problem is that every model capable of doing this well, including Meta’s Segment Anything Model, is far too large to run on the processors that satellites actually carry. Daoyu Shu, Zhan Zhang, and their colleagues at Wuhan University have a direct answer to this: distill SAM’s knowledge into a model that fits on an edge board, using only unlabeled satellite images and no labels whatsoever.
The Three-Layer Problem No One Has Fully Solved
The challenge of getting a powerful segmentation model onto a satellite platform has three distinct layers, and most existing work addresses only one or two of them.
The first layer is computational. SAM's ViT-Large image encoder requires roughly 357 GFLOPs for a single 1024×1024 image. The NVIDIA Jetson Orin NX, one of the most capable edge AI platforms suitable for spacecraft, can comfortably sustain workloads of perhaps 100–120 GFLOPs per image at inference time. SAM simply doesn't fit. Knowledge distillation can compress the model, but standard distillation from SAM still risks transferring the model's natural-image biases along with its useful representations.
The second layer is the domain gap. SAM was trained on 11 million natural photographs from the internet. VHR satellite imagery looks nothing like a street photograph: the camera is pointing straight down, objects appear at unusual scales and orientations, the radiometry is entirely different, and the semantic categories — rooftop, road, paddy field, mangrove — don’t appear in SAM’s training data. Simply fine-tuning or distilling from SAM on satellite data without addressing this gap transfers both the useful knowledge and the damaging natural-image biases.
The third layer is the annotation cost. Existing cross-domain knowledge distillation methods that handle the domain gap well tend to require labeled source data or extensive labeled target-domain datasets. Building pixel-level LULC annotations for global-coverage satellite imagery at 0.5-metre resolution is prohibitively expensive and slow. Any practical system needs to operate with unlabeled imagery.
Mask-CDKD attacks all three layers simultaneously. It uses masked image modeling — the same technique behind MAE — as an implicit self-supervised signal that forces both teacher and student to understand VHR satellite image structure without any labels. It introduces Multi-scale Mixture-of-Adapters (MMoA) into the teacher to filter out natural-image-specific interference while preserving transferable representations. And the entire framework operates on unlabeled 1024×1024 satellite tiles, with no source-domain data required at any stage.
Why “Source-Free and Label-Free” Is Harder Than It Sounds
Previous cross-domain knowledge distillation methods take two main approaches, and both have real problems in the VHR-RS setting. Source-available CDKD methods pair the teacher’s guidance with explicit alignment between labeled source images and unlabeled target images using adversarial or correlation-based objectives. They work reasonably well but require you to store and repeatedly access large-scale source datasets alongside the target imagery, which is computationally expensive and practically annoying. For a satellite operator wanting to update their onboard model with new unlabeled imagery from a new continent, this is a non-starter.
Source-free CDKD methods avoid this by working only with target-domain data. But existing source-free approaches — including the Fourier-domain feature separation in 4Ds and the mutual-information-based relational distillation in InfoSAM — rely on static, rigid decompositions of the feature space into “domain-invariant” and “domain-specific” components. In theory, this is elegant. In practice, VHR satellite scenes have dense multi-scale objects, strong intra-class variation, and rich spatial texture that makes rigid feature decomposition brittle. When the boundary between a road and a building and a bare patch of dirt is spectrally ambiguous, a fixed mathematical factorization of the feature space produces unstable optimization and leaves residual natural-image bias in exactly the regions where precise boundary delineation matters most.
The Mask-CDKD approach sidesteps rigid decomposition entirely. Instead of telling the model what domain-specific features look like and asking it to remove them, it introduces a gating mechanism that learns to reweight multi-scale feature streams adaptively — dynamically emphasizing whatever combination of spatial scales is most informative for each patch of each image.
The MMoA Module: Adapters with a Sense of Scale
Objects in VHR satellite imagery span enormous scale ranges within a single image. A 1024×1024 tile at 0.5 metres per pixel might contain individual cars (a few pixels), buildings (dozens of pixels), and agricultural parcels (hundreds of pixels) simultaneously. A single-scale adapter inserted into SAM’s Transformer blocks would need to choose a receptive field that works for all of these — and no such field exists.
MMoA addresses this with a two-branch design. Each Multi-scale Adapter in MMoA uses an Atrous Spatial Pyramid Pooling (ASPP) module with two groups of dilation rates: {1, 3, 5} for fine-grained detail and {7, 9, 11} for coarse contextual structure. This complementary pairing is non-trivial — ablations in the paper show that removing either the fine or coarse group hurts performance, and the largest drops come from removing mid-scale components, confirming that continuous scale coverage matters more than extreme fine or coarse representations alone.
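To get a feel for what those dilation rates mean on the ground, here is a bit of illustrative arithmetic (mine, not the paper's): the effective extent of a 3×3 atrous kernel with dilation d is (3 - 1)·d + 1 pixels per side, which at the 0.5 m ground sampling distance quoted for the VHR benchmarks gives the following coverage.

# Illustrative arithmetic only (not from the paper): spatial extent covered by
# a 3×3 atrous convolution at each of the MMoA dilation rates, assuming the
# 0.5 m/pixel resolution of the VHR benchmarks.
dilations = [1, 3, 5, 7, 9, 11]
gsd_m = 0.5                                # ground sampling distance in metres
for d in dilations:
    extent_px = (3 - 1) * d + 1            # effective kernel extent in pixels
    print(f"dilation {d:2d}: {extent_px:2d} px ≈ {extent_px * gsd_m:4.1f} m on the ground")
# → fine group {1,3,5} spans roughly 1.5–5.5 m; coarse group {7,9,11} spans 7.5–11.5 m

The fine group stays at car-to-rooftop scale while the coarse group reaches toward parcel-level context, which is exactly the complementarity the ablations reward.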
MASK-CDKD — FULL PIPELINE ARCHITECTURE
══════════════════════════════════════════════════════════════════
INPUT: 1024×1024 VHR-RS image tiles (unlabeled, target domain only)
Mask ratio: 75% (MAE-style random patch masking)
TEACHER ENCODER (SAM ViT-Large, 24 Transformer blocks)
─────────────────────────────────────────────────────
SAM backbone: FROZEN (weights never updated)
MMoA adapters: TRAINABLE (inserted in each Transformer block)
Per Transformer block:
X_n = LayerNorm(X)
F_FF = FeedForward(X_n) ← original FFN (frozen)
MULTI-SCALE ADAPTER (Fine, dilation={1,3,5}):
Y = MLP_down(X_n) ← dimensionality reduction
Y_spatial = Reshape to (B, D_h, H, W)
F_ASPP = σ(W_p[DW-Conv3×3,d(Y), PW-Conv1×1(GAP(Y))])
F_SE = F_ASPP · σ(W2 · σ(W1 · GAP(F_ASPP))) ← SE attention
Z1 = MLP_up(DW-Conv3×3(F_SE))
MULTI-SCALE ADAPTER (Coarse, dilation={7,9,11}):
(same structure as Fine adapter) → Z2
MIXTURE-OF-ADAPTERS GATE:
W_gate = Softmax((X_n·W_q)(X_n·W_k)ᵀ/√Z) · X_n·W_v
(Z=3: one weight per stream: F_FF, Z1, Z2)
X_out = X_n + Σ_{j∈{FF,1,2}} W_gate^(j) ⊙ F_j
STUDENT ENCODER (ViT-Small, 12 Transformer blocks)
─────────────────────────────────────────────────────
All parameters: TRAINABLE
No MMoA (student is the deployment model)
BIDIRECTIONAL COLLABORATIVE DISTILLATION
─────────────────────────────────────────────────────
KD alignment (3 depth pairs): blocks {6,12,18}↔{3,6,9}
L_KD = ||T_l - S_l||²_2
Teacher MAE loss: L_T_MAE = (1/|Ω|) Σ_{i∈Ω} ||I_i - Î_{T,i}||²_2
Student MAE loss: L_S_MAE = (1/|Ω|) Σ_{i∈Ω} ||I_i - Î_{S,i}||²_2
DYNAMIC LOSS SCHEDULE (ratio r = L_T_MAE / L_S_MAE):
Early  (r ≥ 0.85, before first drop):  λ1=0.20, λ2=0.40, λ3=0.40
Middle (r < 0.85):                     λ1=0.60, λ2=0.20, λ3=0.20
Late   (r ≥ 0.95, after middle stage): λ1=0.70, λ2=0.15, λ3=0.15
L_total = λ1·L_KD + λ2·L_T_MAE + λ3·L_S_MAE
DOWNSTREAM DEPLOYMENT (student only, decoder-only tuning)
─────────────────────────────────────────────────────
Student encoder → UPerNet decoder → LULC segmentation
Fine-tune: 30 epochs, AdamW lr=1e-4, frozen backbone
Inference: TensorRT FP16 → Jetson Orin NX → 2.5 FPS, 18.95W
After each adapter extracts its multi-scale features, the Mixture-of-Adapters Gate determines how much weight to give each of the three feature streams — the original FFN output, the fine-scale adapter, and the coarse-scale adapter. The gating uses self-attention: the normalized input features generate Query, Key, and Value matrices, and the resulting attention weights (dimensionality Z=3, one per stream) produce a content-adaptive mixing that changes for every patch, every image, every scene. A paddy field pixel gets different weights than a rooftop pixel, even within the same forward pass.
The Mathematical Formulation
The ASPP feature extraction is \(F_{ASPP} = \sigma\!\left(W_p\left[\text{DW-Conv}_{3\times3,d}(Y),\ \text{PW-Conv}_{1\times1}(\text{GAP}(Y))\right]\right)\), where \(Y\) is the down-projected token map reshaped to a spatial grid, the depthwise atrous convolutions run at the branch's dilation rates \(d\), the pointwise branch injects globally pooled scene context, and \(W_p\) fuses the concatenated responses.
The SE channel attention in each adapter, \(F_{SE} = F_{ASPP} \cdot \sigma\!\left(W_2\,\sigma\!\left(W_1\,\text{GAP}(F_{ASPP})\right)\right)\), reweights channels after the ASPP feature extraction, giving the network additional sensitivity to the feature channels that carry semantically critical spectral cues in satellite imagery. Roads and water bodies have distinctive spectral signatures even before any geometric reasoning, and SE attention lets the adapter learn to amplify those signals.
Bidirectional Distillation: Making the Teacher Learn Too
Standard knowledge distillation keeps the teacher frozen throughout training. This is computationally convenient but problematic in cross-domain settings: if the teacher's features are locked to natural-image representations, distilling from them can only partially suppress the domain gap. The student learns to mimic teacher features that still carry irrelevant natural-image structure.
Mask-CDKD instead allows the teacher's MMoA adapters to be updated during distillation — the teacher's backbone stays frozen, but its adapter modules receive gradient feedback from both the KD alignment loss and the teacher's own MAE reconstruction loss. The student simultaneously updates all its parameters via KD alignment and its own MAE reconstruction. This creates a closed-loop collaboration: the teacher's adapters learn to produce better VHR-RS-aligned features in response to what the student finds useful, and the student learns to align with those progressively improved teacher features.
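A toy sketch (placeholder linear layers, not the paper's models) makes the gradient routing concrete: the alignment loss is not detached on the teacher side, so it updates the teacher's trainable adapters while the frozen backbone stays untouched.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the real networks: a frozen "backbone" plus a trainable
# "adapter" on the teacher side, and a fully trainable student.
teacher_backbone = nn.Linear(16, 16)
teacher_adapter = nn.Linear(16, 16)
student = nn.Linear(16, 16)
for p in teacher_backbone.parameters():
    p.requires_grad = False                       # SAM weights stay frozen

x = torch.randn(4, 16)                            # stand-in for unlabeled tiles
t_feat = teacher_adapter(teacher_backbone(x))     # teacher path
s_feat = student(x)                               # student path

l_kd = F.mse_loss(s_feat, t_feat)                 # no .detach() on the teacher
l_kd.backward()
print(teacher_backbone.weight.grad)               # None (frozen backbone)
print(teacher_adapter.weight.grad is not None)    # True (adapter gets feedback)
print(student.weight.grad is not None)            # True (student gets feedback)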
"The proposed single-stage bidirectional collaborative optimization alleviates the knowledge fixation characteristic of unidirectional distillation and produces a compact student model that attains superior accuracy, stronger generalization, and improved cross-domain adaptability." — Shu, Zhang et al., ISPRS J. Photogramm. Remote Sens. 236 (2026)
The Three-Stage Dynamic Loss Schedule
The loss weights aren't fixed — they evolve during training according to a ratio \(r = \mathcal{L}_{T\_MAE} / \mathcal{L}_{S\_MAE}\) that measures the relative adaptation state of the two models. Early in training, both models are still learning to reconstruct satellite patches, so \(r\) stays near 1.0 and the MAE objectives dominate (λ₁=0.20, λ₂=λ₃=0.40). This prioritizes structural understanding of the target domain before forcing alignment.
As the teacher — with its larger ViT-L capacity — starts outperforming the student on masked reconstruction, \(r\) drops below 0.85. Now the teacher genuinely has something useful to teach: its calibrated features reflect satellite image structure better than the student's. The middle stage shifts emphasis toward KD alignment (λ₁=0.60, λ₂=λ₃=0.20). Finally, when the student catches up and \(r\) rises above 0.95, both models have learned stable target-domain representations, and the late stage fully emphasizes alignment (λ₁=0.70, λ₂=λ₃=0.15) while retaining weak MAE regularization to prevent forgetting.
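Putting the three stages together, using the weight values from the paper as summarized in the pipeline diagram above:

\[
(\lambda_1, \lambda_2, \lambda_3) =
\begin{cases}
(0.20,\ 0.40,\ 0.40) & \text{early: } r \ge 0.85 \text{ (before the first drop)}\\[2pt]
(0.60,\ 0.20,\ 0.20) & \text{middle: } r < 0.85\\[2pt]
(0.70,\ 0.15,\ 0.15) & \text{late: } r \ge 0.95 \text{ (after the middle stage)}
\end{cases}
\qquad
\mathcal{L}_{total} = \lambda_1 \mathcal{L}_{KD} + \lambda_2 \mathcal{L}_{T\_MAE} + \lambda_3 \mathcal{L}_{S\_MAE}.
\]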
LuoJiaCDKD-100K: Building the Right Unlabeled Dataset
The framework's effectiveness depends critically on what unlabeled images you train on. The team curated LuoJiaCDKD-100K — 100,801 images standardized to 1024×1024 pixels — with a specific philosophy: maximize geographic diversity and sensor heterogeneity to ensure that the student's representations don't overfit to any particular region's appearance.
The dataset spans six continents (Asia leads at 36.38%, Europe at 27.58%, North America at 18.53%, Africa at 14.09%, South America at 2.49%, Oceania at 0.93%) and integrates imagery from multiple satellite sensors including WorldView-series and QuickBird. It draws from existing public datasets including LoveDA, VEDAI, LuoJia-HOG, xBD, DeepGlobe Road, and LEVIR-CD, supplemented by independently acquired images. The five LULC categories — buildings, roads, vegetation, water bodies, and bare land — are represented across diverse regional variants, from traditional Chinese architecture in Wuhan to modernist grids in Tucson to dense European city centers.
The scaling experiment tells a clear story. Performance on all three benchmarks improves monotonically as the unlabeled dataset grows from 5K to 100K images, following a logarithmic curve with the steepest gains between 10K and 30K. Crucially, performance has not saturated at 100K, meaning that a larger and more geographically diverse corpus would continue improving the distilled student.
Results Across Three Benchmarks
Main Comparison (mIoU %)
| Method | Backbone | Epochs | Source Data | DeepGlobe | Wuhan-1 | GF-series |
|---|---|---|---|---|---|---|
| LWGANet | End-to-end | 80 | None | 67.71 | 53.69 | 73.34 |
| PyramidMamba | End-to-end | 80 | None | 65.92 | 50.71 | 71.78 |
| EfficientViT-SAM | EfficientViT-L1 | 50 | SA-1B (1B) | 68.14 | 55.06 | 74.08 |
| RS-SAM | ViT-Base | 50 | SA-1B (1B) | 69.82 | 56.83 | 75.78 |
| SelectiveMAE | ViT-Base | 50 | OpticalRS-13M | 70.92 | 58.22 | 77.95 |
| Scale-MAE | ViT-Large | 50 | fMoW (364K) | 70.48 | 57.97 | 77.15 |
| Mask-CDKD (PyTorch) | ViT-Small | 30 | LuoJia-100K | 71.56 | 59.04 | 78.51 |
| Mask-CDKD (LuoJiaNET)* | ViT-Small | 30 | LuoJia-100K | 72.38 | 59.96 | 79.29 |
Foundation-model baselines use a UPerNet decoder with decoder-only tuning (encoder frozen), following the DINOv2 evaluation protocol; LWGANet and PyramidMamba are trained end-to-end. Mask-CDKD reaches the best performance with only 30 fine-tuning epochs versus 50–80 for the baselines. *LuoJiaNET implementation.
Efficiency vs. Accuracy
| Method | Params (M) | FLOPs (G) | FPS (server) | Avg mIoU |
|---|---|---|---|---|
| RSAM-Seg | 87.64 | 357.97 | 9.76 | ~67.5 |
| RS-SAM | 99.11 | 697.47 | 8.79 | ~67.5 |
| SelectiveMAE | 95.28 | 386.88 | 5.93 | 69.0 |
| EfficientViT-SAM | 50.04 | 193.98 | 20.22 | 65.8 |
| BAFNet (lightweight) | 5.38 | 40.65 | 51.85 | 61.9 |
| Mask-CDKD (ours) | 29.65 | 119.76 | 13.62 | 69.7 |
Server inference at 1024×1024, FP16. Scale-MAE runs out of memory at this resolution. Mask-CDKD delivers the highest average mIoU at roughly a third of the compute of the next most accurate method.
The efficiency story deserves emphasis. SelectiveMAE comes closest in mIoU but requires 95.28M parameters and 386.88 GFLOPs and runs at only 5.93 FPS, roughly 3.2× the computation and less than half the throughput of Mask-CDKD's student. RS-SAM needs 697 GFLOPs. Mask-CDKD delivers the best mIoU with 29.65M parameters and 119.76 GFLOPs, a computational budget that actually fits within the practical constraints of edge hardware.
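A quick ratio check against the efficiency table, using the reported numbers (illustrative arithmetic, not code from the paper):

# Ratios computed directly from the efficiency table above.
selectivemae = {"params_M": 95.28, "gflops": 386.88, "fps": 5.93}
mask_cdkd    = {"params_M": 29.65, "gflops": 119.76, "fps": 13.62}
print(f"FLOPs ratio:  {selectivemae['gflops'] / mask_cdkd['gflops']:.2f}x")      # ≈ 3.23x
print(f"Params ratio: {selectivemae['params_M'] / mask_cdkd['params_M']:.2f}x")  # ≈ 3.21x
print(f"Throughput:   {mask_cdkd['fps'] / selectivemae['fps']:.2f}x faster")     # ≈ 2.30x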
Deployed on the NVIDIA Jetson Orin NX as a TensorRT FP16 engine, the Mask-CDKD student achieves 2.5 images per second. The board draws 5.74W at idle and 18.95W during inference, so the model itself consumes a net 13.21W. Segmentation accuracy is 71.54%, 59.03%, and 78.49% mIoU on the three benchmarks, essentially identical to the server GPU results. This is the first paper in this space to provide power-trace measurements demonstrating practical satellite deployment under realistic embedded constraints.
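For readers who want to reproduce the deployment step, a typical PyTorch-to-TensorRT FP16 path looks like the sketch below. This is generic tooling, not the paper's exact pipeline: `model` is a placeholder for the distilled student with its segmentation head, and the file names are hypothetical.

import torch
import torch.nn as nn

# Placeholder for the distilled student + decoder; swap in the real module.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.GELU(), nn.Conv2d(16, 7, 1)
).eval()
dummy = torch.randn(1, 3, 1024, 1024)
torch.onnx.export(model, dummy, "mask_cdkd_student.onnx",
                  input_names=["image"], output_names=["logits"],
                  opset_version=17)
# Then, on the Jetson Orin NX, build the FP16 engine with TensorRT's trtexec:
#   trtexec --onnx=mask_cdkd_student.onnx --fp16 --saveEngine=mask_cdkd_student.engine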
What the Ablation Studies Actually Tell You
The ablation study in Table 2 of the paper disassembles Mask-CDKD into its three core components and tests every combination. The baseline uses a single-scale adapter, additive fusion, and unidirectional distillation — essentially the simplest possible version of the idea. Starting from 66.23% mIoU on DeepGlobe, each component adds value and they combine superadditively.
Multi-scale branches alone: +1.47 points. MMoA gating alone: +2.24 points. Bidirectional distillation alone: +0.82 points. All three together: +5.33 points, noticeably more than the 4.53-point sum of the individual improvements. That superadditivity confirms the components genuinely interact: the gating mechanism becomes more effective when it has multi-scale branches to route, and bidirectional distillation stabilizes learning when the teacher is actively adapting through MMoA.
The domain-separation strategy comparison is equally informative. Additive feature fusion reaches 68.72% average mIoU, Fourier-domain separation 68.38%, mutual-information relational distillation 69.12%, and the proposed gating-based separation 69.70%. The margins are modest, on the order of one mIoU point, but they are consistent across three geographically and semantically distinct datasets, which is a more meaningful result than any single-dataset number.
Limitations Worth Knowing About
LuoJiaCDKD-100K is dominated by urban scenes on six continents, but it remains biased toward optical satellite imagery of cities. Rural areas, deserts, polar regions, and coastal zones are underrepresented. The scaling experiment shows performance hasn't saturated at 100K images, so a more geographically and ecologically diverse dataset would almost certainly improve the distilled model further.
The teacher is a single natural-image foundation model — SAM. The paper explicitly identifies using multiple heterogeneous teachers (DINO, CLIP, domain-specific models) as a future direction, and it's a compelling one. Different foundation models encode different types of prior knowledge, and learning to distill from several simultaneously could produce a student with broader coverage of useful representations.
Finally, 2.5 FPS on the Jetson Orin NX is real-time by some definitions but not others. A satellite imaging at 10 frames per second would need four such processors running in parallel to keep up. The LuoJiaNET implementation reaches 2.97 FPS, and TensorRT with additional quantization to INT8 would push this further — but the paper acknowledges that additional compression through pruning and low-bit quantization remains future work.
Complete End-to-End PyTorch Implementation
The implementation below faithfully reproduces the full Mask-CDKD framework across 8 labeled sections: the Multi-scale Adapter with ASPP and SE channel attention (Eq. 4–6), the Mixture-of-Adapters Gate (Eq. 7–8), the MMoA-enhanced Transformer block (Eq. 2), the masked image modeling reconstruction branches, the bidirectional knowledge distillation loss with dynamic weight scheduling (Eq. 9–12), the ViT-Small student encoder, a synthetic VHR-RS dataset, and a full smoke test with downstream fine-tuning.
# ==============================================================================
# Mask-CDKD: Source-Free & Label-Free Cross-Domain Knowledge Distillation
# from SAM for Satellite Onboard VHR Land-Cover Mapping
# Paper: ISPRS J. Photogramm. Remote Sens. 236 (2026) 1–21
# DOI: https://doi.org/10.1016/j.isprsjprs.2026.03.035
# Authors: Daoyu Shu, Zhan Zhang, Xiao Huang, Ru Wang et al.
# Wuhan University / Emory University
# Code: https://github.com/whujader/mask_cdkd
# ==============================================================================
# Sections:
# 1. Imports & Configuration
# 2. Multi-scale Adapter (ASPP + SE Channel Attention, Eq. 4–6)
# 3. Mixture-of-Adapters Gate (Eq. 7–8)
# 4. MMoA-enhanced Transformer Block (Eq. 2)
# 5. Teacher & Student ViT Encoders
# 6. MAE Reconstruction Branch
# 7. Bidirectional KD Loss with Dynamic Weight Schedule (Eq. 9–12)
# 8. Full Training Loop, Synthetic Dataset & Smoke Test
# ==============================================================================
from __future__ import annotations
import math, random, warnings
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
warnings.filterwarnings("ignore")
# ─── SECTION 1: Configuration ─────────────────────────────────────────────────
class MaskCDKDCfg:
"""
Mask-CDKD configuration matching paper implementation (Section 5.1.3).
Teacher: SAM ViT-Large (24 blocks, hidden=1024) — backbone FROZEN
Only MMoA adapters are trainable
Student: ViT-Small (12 blocks, hidden=384) — fully trainable
Distillation:
- Feature alignment at 3 depth pairs:
Teacher blocks {6, 12, 18} ↔ Student blocks {3, 6, 9}
- MAE mask ratio: 75%, 4-layer MAE decoder
- Dynamic loss schedule driven by r = L_T_MAE / L_S_MAE
Training:
- AdamW, lr=1e-5, weight_decay=0.01, batch=4, 120 epochs
- Input: 1024×1024 VHR-RS tiles (LuoJiaCDKD-100K, unlabeled)
Downstream fine-tuning:
- UPerNet decoder, AdamW lr=1e-4, 30 epochs
- Encoder frozen (DINOv2 evaluation protocol)
"""
# Teacher (SAM ViT-L)
teacher_embed: int = 1024
teacher_heads: int = 16
teacher_depth: int = 24
# Student (ViT-S)
student_embed: int = 384
student_heads: int = 6
student_depth: int = 12
# MMoA adapter
adapter_hidden: int = 256 # bottleneck dim inside adapter
aspp_dilations_fine: List[int] = None # {1, 3, 5}
aspp_dilations_coarse: List[int] = None # {7, 9, 11}
se_reduction: int = 4
# Image / patch params (tiny mode uses 64×64 tiles, 8×8 patches)
img_size: int = 1024
patch_size: int = 16
in_chans: int = 3
num_classes: int = 7 # DeepGlobe 7 LULC classes
# MAE
mask_ratio: float = 0.75
mae_decoder_depth: int = 4
# Distillation alignment layer pairs
distill_pairs: List[Tuple[int, int]] = None
# Dynamic loss schedule thresholds
r_mid_threshold: float = 0.85 # r < this → middle stage
r_late_threshold: float = 0.95 # r ≥ this → late stage
# Training
lr_distill: float = 1e-5
lr_finetune: float = 1e-4
weight_decay: float = 0.01
distill_epochs: int = 120
finetune_epochs: int = 30
batch_size: int = 4
def __init__(self, tiny: bool = False):
self.aspp_dilations_fine = [1, 3, 5]
self.aspp_dilations_coarse = [7, 9, 11]
self.distill_pairs = [(5, 2), (11, 5), (17, 8)] # 0-indexed
if tiny:
self.img_size = 64
self.patch_size = 8
self.teacher_embed = 128
self.teacher_heads = 4
self.teacher_depth = 6
self.student_embed = 64
self.student_heads = 4
self.student_depth = 4
self.adapter_hidden = 32
self.distill_pairs = [(1, 0), (3, 1), (5, 3)]
self.batch_size = 2
self.mae_decoder_depth = 2
@property
def num_patches(self):
return (self.img_size // self.patch_size) ** 2
# ─── SECTION 2: Multi-scale Adapter (ASPP + SE Attention) ────────────────────
class DepthwiseSeparableConv(nn.Module):
"""Depthwise separable atrous convolution (DW-Conv3×3,d in paper Eq. 4)."""
def __init__(self, channels: int, dilation: int = 1):
super().__init__()
self.dw = nn.Conv2d(channels, channels, 3, padding=dilation,
dilation=dilation, groups=channels, bias=False)
self.pw = nn.Conv2d(channels, channels, 1, bias=False)
self.act = nn.GELU()
def forward(self, x): return self.act(self.pw(self.dw(x)))
class ASPPModule(nn.Module):
"""
Atrous Spatial Pyramid Pooling (ASPP) for multi-scale feature extraction (Eq. 4).
Two groups with complementary receptive fields:
Fine branch: dilation rates {1, 3, 5} — local structure
Coarse branch: dilation rates {7, 9, 11} — long-range context
The global average pooling branch captures holistic scene context.
"""
def __init__(self, in_ch: int, out_ch: int, dilations: List[int]):
super().__init__()
self.branches = nn.ModuleList([
DepthwiseSeparableConv(in_ch, d) for d in dilations
])
# Global context branch (GAP)
self.gap_branch = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_ch, in_ch, 1, bias=False),
nn.GELU(),
)
# Fusion: (num_dilations + 1) × in_ch → out_ch
fuse_in = in_ch * (len(dilations) + 1)
self.fuse = nn.Sequential(
nn.Conv2d(fuse_in, out_ch, 1, bias=False),
nn.GELU(),
)
def forward(self, x: Tensor) -> Tensor:
"""x: (B, C, H, W) → (B, out_ch, H, W)"""
H, W = x.shape[-2:]
branch_outs = [b(x) for b in self.branches]
# Upsample GAP back to spatial resolution
gap = F.interpolate(self.gap_branch(x), size=(H, W), mode='bilinear', align_corners=False)
branch_outs.append(gap)
return self.fuse(torch.cat(branch_outs, dim=1))
class SEChannelAttention(nn.Module):
"""
Squeeze-and-Excitation channel attention (Eq. 5).
Reweights channels to emphasize semantically discriminative
spectral responses in VHR satellite imagery.
"""
def __init__(self, channels: int, reduction: int = 4):
super().__init__()
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channels, channels // reduction, bias=False),
nn.GELU(),
nn.Linear(channels // reduction, channels, bias=False),
nn.Sigmoid(),
)
def forward(self, x: Tensor) -> Tensor:
"""x: (B, C, H, W) → (B, C, H, W) channel-reweighted"""
B, C, H, W = x.shape
s = self.gap(x).view(B, C)
w = self.fc(s).view(B, C, 1, 1)
return x * w
class MultiScaleAdapter(nn.Module):
"""
Single multi-scale adapter branch (Section 3.3, Eq. 4–6).
Pipeline per adapter:
1. MLP down: token sequence → lower-dim tokens (Y)
2. Reshape to spatial feature map (B, D_hidden, H, W)
3. ASPP: multi-rate atrous convolution for spatial context
4. SE channel attention: reweight semantically critical channels
5. DW-Conv 3×3: aggregate local spatial context
6. MLP up: restore original embedding dimension
7. Residual connection
Fine adapter: dilation rates {1, 3, 5}
Coarse adapter: dilation rates {7, 9, 11}
"""
def __init__(self, embed_dim: int, hidden_dim: int,
dilations: List[int], patch_hw: int, se_reduction: int = 4):
super().__init__()
self.patch_hw = patch_hw # spatial size after reshape
self.hidden_dim = hidden_dim
# Down-projection
self.mlp_down = nn.Sequential(
nn.Linear(embed_dim, hidden_dim),
nn.GELU(),
)
# ASPP for multi-scale spatial feature extraction
self.aspp = ASPPModule(hidden_dim, hidden_dim, dilations)
# SE channel attention
self.se = SEChannelAttention(hidden_dim, se_reduction)
# Additional spatial aggregation after SE
self.dw_conv = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1,
groups=hidden_dim, bias=False),
nn.GELU(),
)
# Up-projection: restore to embed_dim
self.mlp_up = nn.Linear(hidden_dim, embed_dim)
def forward(self, x_n: Tensor) -> Tensor:
"""
x_n: (B, N, embed_dim) — layer-normed transformer features
Returns: (B, N, embed_dim) — adapter output (before residual)
"""
B, N, D = x_n.shape
H = W = self.patch_hw
# Down-project to hidden_dim (Eq. 4, first step)
y = self.mlp_down(x_n) # (B, N, hidden_dim)
# Reshape to spatial map for convolutional processing
y_spatial = y.reshape(B, H, W, self.hidden_dim).permute(0, 3, 1, 2)
# y_spatial: (B, hidden_dim, H, W)
# Multi-scale feature extraction via ASPP (Eq. 4)
f_aspp = self.aspp(y_spatial) # (B, hidden_dim, H, W)
# SE channel reweighting (Eq. 5)
f_se = self.se(f_aspp) # (B, hidden_dim, H, W)
# Spatial aggregation (Eq. 6, DW-Conv part)
f_dw = self.dw_conv(f_se) # (B, hidden_dim, H, W)
# Reshape back to sequence and up-project (Eq. 6)
z = f_dw.permute(0, 2, 3, 1).reshape(B, N, self.hidden_dim)
z_out = self.mlp_up(z) # (B, N, embed_dim)
return z_out
# ─── SECTION 3: Mixture-of-Adapters Gate ──────────────────────────────────────
class MoAGate(nn.Module):
"""
Mixture-of-Adapters attention-based gating router (Section 3.3, Eq. 7–8).
Generates Z=3 adaptive fusion weights over the three feature streams:
F_FF (original FFN output), F1 (fine adapter), F2 (coarse adapter)
W_gate = Softmax((X_n W_q)(X_n W_k)ᵀ / √Z) · X_n W_v
X_out = X_n + Σ_{j∈{FF,1,2}} W_gate^(j) ⊙ F_j
The key insight: instead of learning fixed mixing weights, the gate
conditions its decisions on the current feature content. A rooftop
patch and a paddy field patch get different adapter weighting even
within the same image, enabling adaptive suppression of natural-image
domain interference.
"""
def __init__(self, embed_dim: int, num_streams: int = 3):
super().__init__()
self.num_streams = num_streams
# Attention projections: D → Z for compact routing
self.W_q = nn.Linear(embed_dim, num_streams, bias=False)
self.W_k = nn.Linear(embed_dim, num_streams, bias=False)
self.W_v = nn.Linear(embed_dim, num_streams, bias=False)
self.scale = math.sqrt(num_streams)
def forward(
self,
x_n: Tensor, # (B, N, D) — normalized input
streams: List[Tensor], # list of (B, N, D) feature streams
) -> Tensor:
"""
Returns fused output: X_n + gated weighted sum of streams.
"""
B, N, D = x_n.shape
        # Compute routing attention (Eq. 7): token-to-token attention whose values
        # live in the Z-dimensional stream space. The (N × N) attention map is
        # quadratic in the number of patches, matching the formula in the docstring.
        q = self.W_q(x_n)                                   # (B, N, Z)
        k = self.W_k(x_n)                                   # (B, N, Z)
        v = self.W_v(x_n)                                   # (B, N, Z)
        attn = F.softmax(q @ k.transpose(-2, -1) / self.scale, dim=-1)  # (B, N, N)
        # Attention-weighted values → (B, N, Z) routing weights, one per stream
        gate_weights = attn @ v
# Weighted sum of feature streams (Eq. 8)
assert len(streams) == self.num_streams
out = x_n.clone()
for j, f_j in enumerate(streams):
w_j = gate_weights[:, :, j].unsqueeze(-1) # (B, N, 1)
out = out + w_j * f_j
return out
# ─── SECTION 4: MMoA-Enhanced Transformer Block ───────────────────────────────
class MMoATransformerBlock(nn.Module):
"""
SAM Transformer block augmented with Multi-scale Mixture-of-Adapters
(MMoA), as illustrated in Fig. 4 of the paper (Eq. 2).
Structure (frozen SAM block + trainable MMoA):
X' = F + Adapter1(Att.(LN(F)))
X_n = LN(X')
F_FF = FFN(X_n) ← original FFN (frozen)
F1 = FineAdapter(X_n) ← fine-scale adapter (trainable)
F2 = CoarseAdapter(X_n) ← coarse-scale adapter (trainable)
X_out = MoAGate(X_n, [F_FF, F1, F2]) ← gated fusion (trainable)
In production, the SAM-specific parts (MHSA, FFN) are from the
original SAM ViT-L weights and remain frozen throughout distillation.
Only the MMoA components are updated.
"""
def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: float,
hidden_dim: int, patch_hw: int, se_reduction: int = 4,
dilations_fine=None, dilations_coarse=None):
super().__init__()
# Standard Transformer components (represent frozen SAM weights)
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
self.norm2 = nn.LayerNorm(embed_dim)
self.ffn = nn.Sequential(
nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
)
# SAM-Adapter style pre-adapter (after self-attention)
self.adapter1 = nn.Sequential(
nn.Linear(embed_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, embed_dim),
)
# MMoA components (trainable)
d_fine = dilations_fine or [1, 3, 5]
d_coarse = dilations_coarse or [7, 9, 11]
self.fine_adapter = MultiScaleAdapter(embed_dim, hidden_dim, d_fine, patch_hw, se_reduction)
self.coarse_adapter = MultiScaleAdapter(embed_dim, hidden_dim, d_coarse, patch_hw, se_reduction)
self.moa_gate = MoAGate(embed_dim, num_streams=3)
# Scaling coefficient γ for Adapter2 in Eq. 2
self.gamma = nn.Parameter(torch.ones(1) * 0.1)
def forward(self, x: Tensor) -> Tensor:
"""
x: (B, N, embed_dim) → (B, N, embed_dim)
"""
# Self-attention with pre-adapter (Eq. 2, first line)
x_norm = self.norm1(x)
attn_out, _ = self.attn(x_norm, x_norm, x_norm)
x = x + attn_out + self.adapter1(attn_out) # Adapter1 after attention
# Normalize before FFN and MMoA
x_n = self.norm2(x)
# Three feature streams
f_ff = self.ffn(x_n) # original FFN
f1 = self.fine_adapter(x_n) # fine-scale ASPP adapter
f2 = self.coarse_adapter(x_n) # coarse-scale ASPP adapter
# Mixture-of-Adapters gated fusion (Eq. 7–8)
x_out = self.moa_gate(x_n, [f_ff, f1, f2])
return x_out
# ─── SECTION 5: Teacher & Student ViT Encoders ────────────────────────────────
class PatchEmbedding(nn.Module):
"""Standard ViT patch embedding: image → token sequence."""
def __init__(self, img_size: int, patch_size: int, in_chans: int, embed_dim: int):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, stride=patch_size)
self.num_patches = (img_size // patch_size) ** 2
self.pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches, embed_dim)
)
nn.init.trunc_normal_(self.pos_embed, std=0.02)
def forward(self, x: Tensor) -> Tensor:
x = self.proj(x).flatten(2).transpose(1, 2)
return x + self.pos_embed
class TeacherEncoder(nn.Module):
"""
SAM-based ViT-Large teacher encoder with MMoA adapters (Section 3.2).
Architecture:
- ViT-Large backbone: FROZEN throughout distillation
- MMoA adapters per block: TRAINABLE (only these update)
At inference, only the student is deployed. The teacher is used
solely during the distillation phase to provide aligned feature
guidance to the student.
In production, load SAM ViT-L weights from:
pip install segment-anything
from segment_anything import sam_model_registry
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l.pth")
Then insert MMoA into each block.
"""
def __init__(self, cfg: MaskCDKDCfg):
super().__init__()
self.cfg = cfg
patch_hw = cfg.img_size // cfg.patch_size
self.patch_embed = PatchEmbedding(cfg.img_size, cfg.patch_size, cfg.in_chans, cfg.teacher_embed)
self.blocks = nn.ModuleList([
MMoATransformerBlock(
embed_dim=cfg.teacher_embed,
num_heads=cfg.teacher_heads,
mlp_ratio=4.0,
hidden_dim=cfg.adapter_hidden,
patch_hw=patch_hw,
se_reduction=cfg.se_reduction,
dilations_fine=cfg.aspp_dilations_fine,
dilations_coarse=cfg.aspp_dilations_coarse,
)
for _ in range(cfg.teacher_depth)
])
self.norm = nn.LayerNorm(cfg.teacher_embed)
# Freeze backbone parameters (only MMoA adapters stay trainable)
self._freeze_backbone()
    def _freeze_backbone(self):
        """Freeze the SAM backbone (patch embedding, attention, FFN, final norm);
        keep only the inserted adapter modules and the MoA gates trainable."""
        trainable_keys = ('adapter1', 'fine_adapter', 'coarse_adapter', 'moa_gate')
        for name, param in self.named_parameters():
            param.requires_grad = any(k in name for k in trainable_keys)
def forward(self, x: Tensor) -> Dict[str, Tensor]:
"""
Returns dict with 'features' (all block outputs for distillation)
and 'final' (last block output).
"""
x = self.patch_embed(x)
block_features = {}
for i, block in enumerate(self.blocks):
x = block(x)
block_features[i] = x
x = self.norm(x)
return {'features': block_features, 'final': x}
class StudentEncoder(nn.Module):
"""
ViT-Small student encoder — the model deployed on satellite (Section 3.2).
Fully trainable throughout distillation. Uses a standard ViT-Small
architecture without any adapters. At downstream fine-tuning time,
the backbone is frozen and only the UPerNet decoder is trained.
Parameters: ~29.65M (paper Table 7)
FLOPs: ~119.76G at 1024×1024 input
"""
def __init__(self, cfg: MaskCDKDCfg):
super().__init__()
self.cfg = cfg
self.patch_embed = PatchEmbedding(cfg.img_size, cfg.patch_size, cfg.in_chans, cfg.student_embed)
class STBlock(nn.Module):
"""Simple ViT transformer block for student."""
def __init__(self, d, h, r):
super().__init__()
self.norm1 = nn.LayerNorm(d)
self.attn = nn.MultiheadAttention(d, h, batch_first=True)
self.norm2 = nn.LayerNorm(d)
self.ffn = nn.Sequential(nn.Linear(d, int(d*r)), nn.GELU(), nn.Linear(int(d*r), d))
def forward(self, x):
xn = self.norm1(x); a, _ = self.attn(xn, xn, xn); x = x + a
return x + self.ffn(self.norm2(x))
self.blocks = nn.ModuleList([
STBlock(cfg.student_embed, cfg.student_heads, 4.0)
for _ in range(cfg.student_depth)
])
self.norm = nn.LayerNorm(cfg.student_embed)
# Project student features to teacher embed dim for distillation alignment
self.proj = nn.Linear(cfg.student_embed, cfg.teacher_embed, bias=False)
def forward(self, x: Tensor) -> Dict[str, Tensor]:
x = self.patch_embed(x)
block_features = {}
for i, block in enumerate(self.blocks):
x = block(x)
block_features[i] = self.proj(x) # project to teacher dim
x = self.norm(x)
return {'features': block_features, 'final': x}
# ─── SECTION 6: MAE Reconstruction Branch ────────────────────────────────────
class MAEDecoder(nn.Module):
"""
Lightweight MAE decoder for masked image reconstruction (He et al., 2022).
Applied to both teacher and student encoders during Mask-CDKD training.
    Follows the paper's settings:
        - Mask ratio: 75% (in the paper only the visible 25% of patches reach
          the encoder; this simplified version encodes all patches and
          reconstructs only the masked 75%)
        - 4 decoder layers
        - The paper uses sparse convolutions in the reconstruction branch
          (following SparK); here a plain MLP decoder is used for simplicity.
The reconstruction loss provides implicit supervision on VHR-RS structure
without any pixel-level semantic labels.
"""
def __init__(self, embed_dim: int, patch_size: int, in_chans: int, depth: int = 4):
super().__init__()
self.patch_size = patch_size
self.in_chans = in_chans
decoder_dim = max(64, embed_dim // 4)
layers = []
in_d = embed_dim
for _ in range(depth):
layers += [nn.Linear(in_d, decoder_dim), nn.GELU()]
in_d = decoder_dim
layers.append(nn.Linear(decoder_dim, patch_size * patch_size * in_chans))
self.dec = nn.Sequential(*layers)
    def forward(self, tokens: Tensor, mask: Tensor) -> Tensor:
        """
        tokens: (B, N, D) encoder output for all patches
        mask: (B, N) bool, True = masked (target for reconstruction)
        Returns: (N_masked_total, P*P*C) reconstructed pixel values
        """
        masked_tokens = tokens[mask]              # (N_masked_total, D)
        return self.dec(masked_tokens)
    def mae_loss(self, pred: Tensor, images: Tensor, mask: Tensor,
                 patch_size: int) -> Tensor:
        """
        MAE reconstruction loss on masked patches only (Eq. 10–11).
        pred: (N_masked_total, P*P*C) reconstructed pixel values
        images: (B, C, H, W) original images
        mask: (B, N) bool, True = masked
        """
        B, C, H, W = images.shape
        p = patch_size
        N = (H // p) * (W // p)
        # Patchify the target image into (B, N, C*p*p)
        img_patches = images.reshape(B, C, H // p, p, W // p, p)
        img_patches = img_patches.permute(0, 2, 4, 1, 3, 5)
        img_patches = img_patches.reshape(B, N, C * p * p)
        # Select the masked patches and compare against the predictions
        target_masked = img_patches[mask]         # (N_masked_total, C*p*p)
        return F.mse_loss(pred, target_masked, reduction='mean')
# ─── SECTION 7: Bidirectional KD Loss with Dynamic Schedule ──────────────────
class BidirectionalKDLoss(nn.Module):
"""
Single-stage bidirectional collaborative distillation loss (Section 3.4, Eq. 9–12).
Three components:
L_KD = ||T_l - S_l||²_2 (cross-domain feature alignment)
L_T_MAE = MAE loss for teacher (target-domain structure learning)
L_S_MAE = MAE loss for student (target-domain structure learning)
L_total = λ1·L_KD + λ2·L_T_MAE + λ3·L_S_MAE, λ1+λ2+λ3=1
Dynamic weight schedule driven by r = L_T_MAE / L_S_MAE:
        Early  (r ≥ 0.85, before first drop): λ = (0.20, 0.40, 0.40) → emphasize reconstruction
        Middle (r < 0.85):                    λ = (0.60, 0.20, 0.20) → emphasize alignment
        Late   (r ≥ 0.95, after middle):      λ = (0.70, 0.15, 0.15) → maximize alignment
    The ratio r measures whether the teacher (larger capacity) has begun
    clearly outperforming the student on masked reconstruction. When it does,
    the teacher genuinely has improved target-domain knowledge to transfer.
    Because the early and late stages both satisfy r ≥ 0.85, the scheduler also
    tracks whether the middle stage has been entered before allowing the
    transition to late.
    """
    def __init__(self, cfg: MaskCDKDCfg):
        super().__init__()
        self.cfg = cfg
        self.r_mid = cfg.r_mid_threshold
        self.r_late = cfg.r_late_threshold
        self.entered_middle = False   # set once r first drops below r_mid
    def get_weights(self, l_t_mae: float, l_s_mae: float) -> Tuple[float, float, float]:
        """
        Determine (λ1, λ2, λ3) from the current adaptation-state ratio r.
        """
        r = l_t_mae / (l_s_mae + 1e-8)
        if not self.entered_middle:
            if r >= self.r_mid:
                # Early stage: teacher not yet outperforming the student
                return 0.20, 0.40, 0.40
            self.entered_middle = True
        if r >= self.r_late:
            # Late stage: student has caught back up → maximize alignment
            return 0.70, 0.15, 0.15
        # Middle stage: teacher clearly better → shift emphasis to alignment
        return 0.60, 0.20, 0.20
def forward(
self,
teacher_feats: Dict[int, Tensor], # teacher block idx → (B, N, D)
student_feats: Dict[int, Tensor], # student block idx → (B, N, D)
l_t_mae: Tensor, # teacher MAE loss (scalar)
l_s_mae: Tensor, # student MAE loss (scalar)
distill_pairs: List[Tuple[int,int]], # [(teacher_block, student_block)]
) -> Tuple[Tensor, Dict]:
"""
Returns total loss and a dict with individual components.
"""
# Cross-domain feature alignment loss (Eq. 9)
l_kd = torch.tensor(0.0, device=l_t_mae.device)
for t_idx, s_idx in distill_pairs:
if t_idx in teacher_feats and s_idx in student_feats:
t_feat = teacher_feats[t_idx]
s_feat = student_feats[s_idx]
                # No .detach() on the teacher features: the alignment loss also
                # back-propagates into the teacher's MMoA adapters (bidirectional)
                l_kd = l_kd + F.mse_loss(s_feat, t_feat)
# Dynamic weight schedule (Eq. 12)
lam1, lam2, lam3 = self.get_weights(
l_t_mae.item(), l_s_mae.item()
)
# Total loss
l_total = lam1 * l_kd + lam2 * l_t_mae + lam3 * l_s_mae
return l_total, {
'l_kd': l_kd.item(),
'l_t_mae': l_t_mae.item(),
'l_s_mae': l_s_mae.item(),
'lambda1': lam1, 'lambda2': lam2, 'lambda3': lam3,
'r_ratio': l_t_mae.item() / (l_s_mae.item() + 1e-8)
}
class MaskCDKD(nn.Module):
"""
Full Mask-CDKD framework (Section 3.2, Fig. 2).
Combines:
- Frozen SAM teacher encoder with trainable MMoA adapters
- Fully trainable ViT-Small student encoder
- MAE reconstruction branches for both teacher and student
- Bidirectional KD loss with dynamic weight scheduling
Training: only teacher MMoA + student encoder updated via gradients
Inference: only student encoder deployed (teacher discarded)
"""
def __init__(self, cfg: MaskCDKDCfg):
super().__init__()
self.cfg = cfg
self.teacher = TeacherEncoder(cfg)
self.student = StudentEncoder(cfg)
self.teacher_mae = MAEDecoder(cfg.teacher_embed, cfg.patch_size,
cfg.in_chans, cfg.mae_decoder_depth)
self.student_mae = MAEDecoder(cfg.student_embed, cfg.patch_size,
cfg.in_chans, cfg.mae_decoder_depth)
self.criterion = BidirectionalKDLoss(cfg)
def random_mask(self, B: int, N: int, ratio: float, device) -> Tensor:
"""Generate random boolean mask: True = masked patch."""
n_mask = int(N * ratio)
noise = torch.rand(B, N, device=device)
ids_sort = torch.argsort(noise, dim=1)
mask = torch.zeros(B, N, dtype=torch.bool, device=device)
mask.scatter_(1, ids_sort[:, :n_mask], True)
return mask
def forward(self, images: Tensor) -> Tuple[Tensor, Dict]:
"""
Full Mask-CDKD forward pass on unlabeled VHR-RS images.
images: (B, 3, H, W) — unlabeled target-domain satellite tiles
Returns: (total_loss, loss_components_dict)
"""
B, C, H, W = images.shape
device = images.device
N = (H // self.cfg.patch_size) ** 2
# Generate masks (75% masked for both teacher and student)
mask = self.random_mask(B, N, self.cfg.mask_ratio, device)
# Teacher forward (only MMoA adapters update)
t_out = self.teacher(images)
t_feats = t_out['features']
t_final = t_out['final']
# Teacher MAE reconstruction on masked patches (Eq. 10)
t_recon = self.teacher_mae(t_final, mask)
        l_t_mae = self.teacher_mae.mae_loss(t_recon, images, mask, self.cfg.patch_size)
# Student forward (all parameters update)
s_out = self.student(images)
s_feats = s_out['features']
s_final = s_out['final']
# Student MAE reconstruction on masked patches (Eq. 11)
s_recon = self.student_mae(s_final, mask)
        l_s_mae = self.student_mae.mae_loss(s_recon, images, mask, self.cfg.patch_size)
# Bidirectional KD with dynamic weights (Eq. 12)
total_loss, components = self.criterion(
t_feats, s_feats, l_t_mae, l_s_mae,
self.cfg.distill_pairs
)
return total_loss, components
# ─── SECTION 8: Dataset, Training Loop & Smoke Test ──────────────────────────
class SyntheticVHRDataset(Dataset):
"""
Synthetic VHR-RS dataset for testing Mask-CDKD.
Replace with LuoJiaCDKD-100K for production:
100,801 unlabeled 1024×1024 VHR-RS images
Global coverage: Asia 36.4%, Europe 27.6%, N.America 18.5%, ...
Sources: LoveDA, VEDAI, xBD, DeepGlobe Road, LEVIR-CD + acquired
All images stored as RGB, no annotations required
Downstream fine-tuning datasets:
DeepGlobe: https://competitions.codalab.org/competitions/18468
803 images, 2448×2448, 0.5m, 7 LULC classes
Wuhan-1: In-house Wuhan University satellite data
GF-series: Gaofen satellite imagery, Guangdong Province
"""
def __init__(self, n: int = 200, cfg: Optional[MaskCDKDCfg] = None):
self.n = n
self.cfg = cfg or MaskCDKDCfg(tiny=True)
def __len__(self): return self.n
def __getitem__(self, idx):
# Synthetic VHR satellite tile
image = torch.rand(3, self.cfg.img_size, self.cfg.img_size)
return {'image': image}
class SyntheticSegDataset(Dataset):
"""Labeled dataset for downstream fine-tuning evaluation."""
def __init__(self, n: int = 100, cfg: Optional[MaskCDKDCfg] = None):
self.n = n
self.cfg = cfg or MaskCDKDCfg(tiny=True)
def __len__(self): return self.n
def __getitem__(self, idx):
image = torch.rand(3, self.cfg.img_size, self.cfg.img_size)
# Simulated LULC segmentation mask (7 classes for DeepGlobe)
label = torch.randint(0, self.cfg.num_classes,
(self.cfg.img_size, self.cfg.img_size))
return {'image': image, 'label': label}
def get_trainable_params(model: MaskCDKD) -> List:
"""
Return only trainable parameters: teacher MMoA adapters + student + MAE decoders.
Teacher backbone remains frozen throughout distillation.
"""
trainable = []
for name, p in model.named_parameters():
if p.requires_grad:
trainable.append(p)
return trainable
def run_distillation(
model: MaskCDKD,
loader: DataLoader,
device: torch.device,
epochs: int = 3,
) -> List[float]:
"""
Mask-CDKD distillation training loop (Section 5.1.3).
Production: 120 epochs, AdamW lr=1e-5, weight_decay=0.01
"""
trainable = get_trainable_params(model)
opt = torch.optim.AdamW(trainable, lr=model.cfg.lr_distill,
weight_decay=model.cfg.weight_decay)
history = []
model.train()
for ep in range(1, epochs + 1):
ep_loss = 0.0
for batch in loader:
imgs = batch['image'].to(device)
opt.zero_grad()
loss, comps = model(imgs)
loss.backward()
torch.nn.utils.clip_grad_norm_(trainable, 1.0)
opt.step()
ep_loss += loss.item()
avg = ep_loss / max(1, len(loader))
history.append(avg)
stage = (
"Early" if comps['lambda1'] == 0.20 else
"Middle" if comps['lambda1'] == 0.60 else "Late"
)
print(f" Distil Ep {ep}/{epochs} | Loss={avg:.4f} | Stage={stage} | "
f"λ=({comps['lambda1']:.2f},{comps['lambda2']:.2f},{comps['lambda3']:.2f}) | "
f"r={comps['r_ratio']:.3f}")
return history
def compute_miou(preds: Tensor, labels: Tensor, num_classes: int) -> float:
"""mIoU metric for LULC segmentation evaluation (Eq. 13)."""
ious = []
for c in range(num_classes):
tp = ((preds == c) & (labels == c)).float().sum()
fp = ((preds == c) & (labels != c)).float().sum()
fn = ((preds != c) & (labels == c)).float().sum()
iou = (tp / (tp + fp + fn + 1e-8)).item()
ious.append(iou)
return float(np.mean(ious))
if __name__ == "__main__":
print("=" * 70)
print(" Mask-CDKD — Full Smoke Test")
print(" Shu, Zhang et al. (Wuhan University, ISPRS 2026)")
print("=" * 70)
torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cpu")
cfg = MaskCDKDCfg(tiny=True)
# ── 1. Build model ───────────────────────────────────────────────────────
print("\n[1/6] Building Mask-CDKD framework...")
model = MaskCDKD(cfg).to(device)
total = sum(p.numel() for p in model.parameters()) / 1e6
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
frozen = total - trainable
print(f" Total: {total:.2f}M | Trainable: {trainable:.2f}M | Frozen backbone: {frozen:.2f}M")
# ── 2. MMoA forward ──────────────────────────────────────────────────────
print("\n[2/6] MMoA-enhanced Transformer block forward pass...")
patch_hw = cfg.img_size // cfg.patch_size
N = patch_hw ** 2
dummy_tokens = torch.randn(2, N, cfg.teacher_embed)
block = model.teacher.blocks[0]
out = block(dummy_tokens)
print(f" Input: {tuple(dummy_tokens.shape)} → Output: {tuple(out.shape)}")
# ── 3. Dynamic loss schedule ──────────────────────────────────────────────
print("\n[3/6] Dynamic weight schedule test...")
criterion = BidirectionalKDLoss(cfg)
    # The scheduler is stateful: late requires the middle stage to have been entered first
    for t_mae, s_mae, label in [(0.9, 0.8, "Early"), (0.3, 0.8, "Middle"), (0.8, 0.8, "Late")]:
l1, l2, l3 = criterion.get_weights(t_mae, s_mae)
r = t_mae / (s_mae + 1e-8)
print(f" {label}: r={r:.2f} → λ=({l1:.2f}, {l2:.2f}, {l3:.2f})")
# ── 4. Full forward pass ────────────────────────────────────────────────
print("\n[4/6] Full Mask-CDKD forward pass (distillation)...")
dummy_imgs = torch.randn(2, 3, cfg.img_size, cfg.img_size)
loss, comps = model(dummy_imgs)
print(f" Total loss: {loss.item():.4f}")
print(f" KD loss: {comps['l_kd']:.4f}")
print(f" Teacher MAE: {comps['l_t_mae']:.4f}")
print(f" Student MAE: {comps['l_s_mae']:.4f}")
print(f" Adaptation stage: λ=({comps['lambda1']:.2f},{comps['lambda2']:.2f},{comps['lambda3']:.2f})")
# ── 5. Short distillation run ────────────────────────────────────────────
print("\n[5/6] Short distillation run (2 epochs)...")
dataset = SyntheticVHRDataset(n=32, cfg=cfg)
loader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)
run_distillation(model, loader, device, epochs=2)
# ── 6. Downstream evaluation check ──────────────────────────────────────
print("\n[6/6] Downstream segmentation check...")
model.student.eval()
dummy_img = torch.randn(1, 3, cfg.img_size, cfg.img_size)
with torch.no_grad():
s_out = model.student(dummy_img)
feats = s_out['final'] # (1, N, D)
# Simple linear classifier head for smoke test
head = nn.Linear(cfg.student_embed, cfg.num_classes)
logits = head(feats) # (1, N, num_classes)
preds = logits.argmax(-1).reshape(1, patch_hw, patch_hw)
labels = torch.randint(0, cfg.num_classes, (1, patch_hw, patch_hw))
miou = compute_miou(preds, labels, cfg.num_classes)
print(f" Student feature shape: {tuple(feats.shape)}")
print(f" mIoU (random baseline): {miou:.4f}")
print("\n" + "="*70)
print("✓ All checks passed. Mask-CDKD is ready for real VHR-RS data.")
print("="*70)
print("""
Production deployment steps:
1. Teacher setup (SAM ViT-Large):
pip install git+https://github.com/facebookresearch/segment-anything.git
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
from segment_anything import sam_model_registry
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
# Insert MMoATransformerBlock into sam.image_encoder.blocks
2. LuoJiaCDKD-100K dataset:
100,801 unlabeled 1024×1024 VHR-RS tiles (global coverage)
Available at: https://github.com/whujader/mask_cdkd
No labels needed — only target-domain RGB satellite images
3. Distillation training (120 epochs on H800 80GB GPU):
AdamW, lr=1e-5, weight_decay=0.01, batch=4
Input: 1024×1024 tiles, mask_ratio=0.75, mae_decoder_depth=4
Distill teacher blocks {6,12,18} ↔ student blocks {3,6,9}
Monitor r = L_T_MAE / L_S_MAE for dynamic stage transitions
4. Downstream fine-tuning (30 epochs on V100 32GB GPU):
Freeze student backbone; train UPerNet decoder only
AdamW, lr=1e-4, batch=4
DeepGlobe: 1024×1024 crops, 7 classes → target mIoU 71.56%
Wuhan-1: 1024×1024 crops, 8 classes → target mIoU 59.04%
GF-series: 1024×1024 crops, 11 classes → target mIoU 78.51%
5. Embedded deployment (Jetson Orin NX 16GB):
TensorRT FP16 conversion → .engine file
Throughput: 2.50 FPS at 1024×1024 (2.97 FPS with LuoJiaNET)
Power: 18.95W average (13.21W net over idle baseline)
Accuracy preservation: <0.02% mIoU degradation vs GPU server
6. Evaluation metrics (Section 5.1.1, Eq. 13-16):
mIoU = (1/N) Σ TP_i / (TP_i + FP_i + FN_i)
OA = Σ TP_i / T (pixel-level accuracy)
mF1 = (1/N) Σ 2·Precision_i·Recall_i / (Precision_i + Recall_i)
""")
Paper, Code & Dataset
Mask-CDKD's full implementation, pretrained weights, and the LuoJiaCDKD-100K dataset are available on GitHub. The paper is published in ISPRS Journal of Photogrammetry and Remote Sensing.
Shu, D., Zhang, Z., Huang, X., Wang, R., Jia, N., Fu, X., Yang, B., Wan, F., Lu, J., & Gong, J. (2026). Mask-CDKD: A source-free and label-free cross-domain knowledge distillation framework from SAM for satellite onboard VHR land-cover mapping. ISPRS Journal of Photogrammetry and Remote Sensing, 236, 1–21. https://doi.org/10.1016/j.isprsjprs.2026.03.035
This article is an independent editorial analysis of peer-reviewed research. The PyTorch implementation is an educational adaptation; production use requires SAM weights and the LuoJiaCDKD-100K dataset. Supported by the National Natural Science Foundation of China (42090011, 42271354, 42371367).