CRGenNet: Cloud-Free Optical Image Generation Using SAR and Contaminated Optical Data | AI Trend Blend

CRGenNet: How Satellites Can See Through Clouds by Never Assuming the Sky Is Clear

Researchers at the University of Twente built an image generation network that does what every prior method refused to do — accept that your helper image might also be covered in clouds. CRGenNet fuses Sentinel-1 SAR radar data with contaminated Sentinel-2 optical imagery to synthesize missing scenes, and backs it with a new flood-disaster benchmark that finally reflects real-world messiness.

CRGenNet · Cloud Removal · SAR-Optical Fusion · DownUpBlock · FusionAttention · Sentinel-1 · Sentinel-2 · GAN · TCSEN12 Dataset · Disaster Monitoring

Satellite optical imagery has a dirty secret that nobody in the deep learning community wanted to talk about: most training datasets assume your auxiliary reference image is cloud-free. In the real world — especially during floods, hurricanes, and other disasters that demand timely satellite observation — clouds don’t politely step aside for the control image. The Zhengzhou flood of July 2021 put this problem in sharp relief: when you need to see the flood extent most urgently, clouds cover both the target and the backup. CRGenNet from University of Twente’s ITC faculty is the first method purpose-built to handle this exact nightmare scenario, and the results show it outperforms every tested baseline when the auxiliary optical data is itself contaminated.


The Problem Nobody Wanted to Admit

Every SAR-optical image fusion paper for cloud removal follows roughly the same recipe: take a cloudy optical image at time t1, pair it with an all-weather SAR image, and optionally a cloud-free optical image from another date, then learn to generate the missing scene. The “cloud-free auxiliary optical image” assumption is so standard it barely gets mentioned — it’s baked into dataset construction, into evaluation splits, and into loss functions that make no provision for cloud contamination in the reference signal.

But look at how real satellite tasking works. If your area of interest is hit by a disaster — a flood, a wildfire, a typhoon — you’re trying to generate imagery from dates when the entire region may be persistently cloud-covered for days or weeks. The observation before the event (your t1 “reference”) might have patchy cloud. The observation after (your target t2) is almost certainly under cloud. The SAR radar cuts through all of it, because SAR doesn’t care about weather. But the moment your pipeline assumes it can detect meaningful changes between optical images at t1 and t2, contaminated optical inputs corrupt everything downstream.

CRGenNet solves this in two ways: architecturally, by fusing SAR and optical data at t1 before any temporal reasoning happens (so the network never “trusts” the optical blindly), and via the FusionAttention module, which dynamically evaluates feature quality from each date rather than treating all pixels equally. The paper also introduces TCSEN12, a dataset specifically built around the Zhengzhou flood where contaminated auxiliary optical data is the norm rather than the exception.

The Core Design Philosophy

CRGenNet is structured around a simple but radical assumption: you cannot trust your auxiliary optical image to be clean. The DownUpBlock fuses co-acquired SAR and optical data from t1 first — giving the network an all-weather structural anchor before any temporal comparison is attempted. The FusionAttention module then selectively aligns the t1-fused features with the t2 SAR signal, actively suppressing regions where cloud contamination would mislead the generator. The result is a pipeline that degrades gracefully when inputs are dirty, instead of catastrophically.

The CRGenNet Architecture — A Walk Through Every Block

INPUTS:
  S1_t1  (Sentinel-1 SAR at time t1)   — all-weather radar
  S2_t1  (Sentinel-2 optical at t1)     — possibly cloud-contaminated
  S1_t2  (Sentinel-1 SAR at time t2)   — structural reference for target date
         │
┌────────▼──────────────────────────────────────────────────────────┐
│  STEP 1: INITIAL FEATURE EXTRACTION                                │
│    MiniConvMish(S1_t1)  →  F_sar_t1  (local structural features)  │
│    ConvMishBlock(S1_t2) →  F_sar_t2  (target-date structure)      │
│    Concat(F_sar_t1, S2_t1) → combined_t1  (SAR guides optical)    │
│    ConvMishBlock(combined_t1) → F_t1_fused                        │
└────────┬──────────────────────────────────────────────────────────┘
         │
┌────────▼──────────────────────────────────────────────────────────┐
│  STEP 2: DOWNUPBLOCK  (cross-modal feature extraction, t1)         │
│    Input: F_t1_fused  (SAR + optical from same timestamp)          │
│                                                                    │
│    DownSample ×2:                                                  │
│      ReplicationPad → Conv3×3 → InstanceNorm → Mish               │
│      ReplicationPad → Conv3×3 → InstanceNorm → Mish               │
│      → compressed spatially, noise suppressed                      │
│                                                                    │
│    UpSample ×2:                                                    │
│      ReplicationPad → DeConv3×3 → InstanceNorm → Mish             │
│      ReplicationPad → DeConv3×3 → InstanceNorm → Mish             │
│    + ConvMishBlock × 2 + Conv + Dropout                            │
│    + Residual connection (input + output)                          │
│    → F_downup: rich t1 features, SAR-corrected optical context     │
└────────┬──────────────────────────────────────────────────────────┘
         │
┌────────▼──────────────────────────────────────────────────────────┐
│  STEP 3: FUSIONATTENTION  (temporal cross-attention)               │
│    Input1: F_downup  (t1 SAR+optical features)                     │
│    Input2: F_sar_t2  (t2 SAR structural features)                  │
│                                                                    │
│    Q1,Q2 = Conv(Input1), Conv(Input2) → queries                    │
│    K1,K2 = Conv(Input1), Conv(Input2) → keys                       │
│    V1,V2 = Conv(Input1), Conv(Input2) → values                     │
│                                                                    │
│    Q = L2_norm(concat(Q1,Q2))   K = L2_norm(concat(K1,K2))        │
│    weights = Softmax(QK^T / sqrt(d_k))  — shared attention map     │
│    attention1 = weights · V1                                       │
│    attention2 = weights · V2                                       │
│    y1 = Input1 + γ·attention1  (γ learnable, init=0)              │
│    y2 = Input2 + γ·attention2                                      │
│    Output = Concat(y1, y2)  → temporally-aligned features          │
└────────┬──────────────────────────────────────────────────────────┘
         │
┌────────▼──────────────────────────────────────────────────────────┐
│  STEP 4: ExtendedConvBlock → SwinBlocks (4 scaled-up, multi-scale) │
│    SwinBlock scale 1 → Channel Attention → Upsample               │
│    SwinBlock scale 2 → Channel Attention → Upsample               │
│    SwinBlock scale 3 → Channel Attention → Upsample               │
│    SwinBlock scale 4 → Spatial Attention → Upsample               │
│    (Scaled cosine attention, hierarchical feature representation)  │
└────────┬──────────────────────────────────────────────────────────┘
         │
┌────────▼──────────────────────────────────────────────────────────┐
│  STEP 5: DECODER  (multi-scale feature fusion)                     │
│    Fuse skip connections from each SwinBlock scale                 │
│    Channel Attention on low-resolution, multi-channel features     │
│    Spatial Attention on high-resolution, few-channel features      │
│    Dropout → MiniConvMish → Upsample → Output (cloud-free S2_t2)  │
└────────────────────────────────────────────────────────────────────┘

DISCRIMINATOR: 3 hierarchical sub-discriminators (D1 full, D2 half,
D3 quarter resolution) each with BN → Mish → SpectralNorm →
Channel+Spatial Attention → Sigmoid. Combined: αD1 + βD2 + γD3.

Module 1: DownUpBlock — SAR Corrects Optical Before Anything Else Happens

The DownUpBlock is the architectural expression of a simple insight: before you try to compare features across time (t1 versus t2), you should first integrate the two modalities within the same timestamp. SAR and optical images at t1 are complementary — SAR reliably captures structure and geometry regardless of weather, while optical captures spectral content in cloud-free regions. By concatenating S1_t1 with S2_t1 and immediately fusing them through the DownUpBlock, the network builds a joint representation that is partly immune to cloud contamination in the optical channel.

The block’s architecture is a residual encoder-decoder: two strided convolutional downsampling layers (each with InstanceNorm and Mish activation) compress the spatial resolution and force the network to extract only the most stable features, suppressing high-frequency speckle noise from the SAR. Two transposed-convolutional upsampling layers restore the original resolution, with a residual connection adding the input features back at the end. The key is the dropout layer in the bottleneck — it prevents the network from memorizing cloud-pattern locations and instead generalizes to structural content.

Why Mish Activation Over ReLU?

CRGenNet uses the Mish activation function throughout rather than the standard ReLU. Mish (\(f(x) = x \cdot \tanh(\text{softplus}(x))\)) is smooth and non-monotonic, which means gradients flow more reliably through deep stacks of convolutional layers. It also avoids the “dying neuron” problem that can hurt ReLU networks when dealing with the extreme contrast between cloud-covered and clear pixel regions in satellite imagery — a practical advantage in this setting that several prior SAR-optical fusion methods have overlooked.
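The shape of Mish is easy to verify numerically. This is a plain-Python sketch (not taken from the paper's code) showing that, unlike ReLU, Mish lets a small negative signal through instead of zeroing it, so gradients keep flowing for negative pre-activations:

```python
import math

def softplus(x: float) -> float:
    # Numerically stable softplus: log(1 + e^x)
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)

def mish(x: float) -> float:
    # Mish: f(x) = x * tanh(softplus(x)) -- smooth and non-monotonic
    return x * math.tanh(softplus(x))

def relu(x: float) -> float:
    return max(0.0, x)

# ReLU zeroes all negative inputs; Mish keeps a small negative response
# (its minimum sits near x = -1.19), so "dying neurons" are avoided.
for x in (-2.0, -0.5, 0.0, 1.0, 3.0):
    print(f"x={x:+.1f}  relu={relu(x):+.4f}  mish={mish(x):+.4f}")
```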

Module 2: FusionAttention — Knowing When Not to Trust Your Own Features


After the DownUpBlock produces a fused t1 representation, CRGenNet still needs to align it with the t2 structural information from S1_t2. This is where prior methods fail: change-detection-based approaches compute pixel-level differences between t1 and t2 optical images, but those differences are meaningless wherever cloud contamination exists in either image.

FusionAttention sidesteps change detection entirely by computing a joint attention map from concatenated queries and keys across both inputs, then applying that same shared attention to the value embeddings of each input separately. The learnable scalar γ (initialized to zero, so the block starts as an identity and gradually learns to route attention) means the module only activates its temporal fusion as training confirms it helps.
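The zero-initialized γ gating described above can be checked in a toy module. This is a minimal sketch of the Eq. 4 update pattern (a hypothetical `GatedResidual` stand-in, not the paper's full FusionAttention):

```python
import torch
import torch.nn as nn

class GatedResidual(nn.Module):
    """Toy version of the Eq. 4 update y = x + gamma * f(x), with gamma init 0."""
    def __init__(self, ch: int):
        super().__init__()
        self.f = nn.Conv2d(ch, ch, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.gamma * self.f(x)

torch.manual_seed(0)
x = torch.randn(1, 8, 16, 16)
block = GatedResidual(8)
# gamma starts at 0, so the untrained block is an exact identity map;
# training then grows gamma only where the attention branch proves useful.
print(torch.equal(block(x), x))
```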

Eq. 1–4 — FusionAttention $$Q = L_2\bigl(\text{concat}(Q_1, Q_2)\bigr), \quad K = L_2\bigl(\text{concat}(K_1, K_2)\bigr)$$ $$\text{weights} = \text{Softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)$$ $$\text{attn}_1 = \text{weights} \cdot V_1, \quad \text{attn}_2 = \text{weights} \cdot V_2$$ $$y_1 = x_1 + \gamma \cdot \text{attn}_1, \quad y_2 = x_2 + \gamma \cdot \text{attn}_2$$

The L2 normalization of the concatenated queries and keys is a subtle but important detail. Without it, one temporal stream could numerically dominate the attention scores, defeating the purpose of joint weighting. With it, the attention weights genuinely reflect cross-modal and cross-temporal feature similarity.
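The dominance effect is easy to reproduce on toy tensors. The sketch below (illustrative only, with made-up magnitudes rather than real features) concatenates a large-magnitude query/key stream with a small one and compares attention weights with and without the L2 normalization of Eq. 1:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hw, c = 4, 8
# Stream 1 (say, t1 SAR+optical) is 100x larger in magnitude than stream 2.
q1, k1 = torch.randn(c, hw) * 100.0, torch.randn(c, hw) * 100.0
q2, k2 = torch.randn(c, hw), torch.randn(c, hw)

def attn_weights(q: torch.Tensor, k: torch.Tensor, normalize: bool) -> torch.Tensor:
    if normalize:
        q = F.normalize(q, p=2, dim=0)   # unit-norm columns, as in Eq. 1
        k = F.normalize(k, p=2, dim=0)
    logits = q.T @ k / (q.shape[0] ** 0.5)
    return F.softmax(logits, dim=-1)

raw = attn_weights(torch.cat([q1, q2]), torch.cat([k1, k2]), normalize=False)
nrm = attn_weights(torch.cat([q1, q2]), torch.cat([k1, k2]), normalize=True)

# Without L2 normalization the loud stream saturates the softmax into
# near-one-hot weights; with it, the weights stay genuinely joint.
print("max weight, unnormalized:", raw.max().item())
print("max weight, normalized:  ", nrm.max().item())
```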

Module 3: Multi-Scale Discriminator with Attention


The discriminator is a three-branch multi-scale design. D1 operates at full input resolution, D2 at half, and D3 at quarter resolution. Each branch follows the same internal structure: BatchNorm → Mish → SpectralNorm → ChannelAttention → SpatialAttention → Sigmoid. Spectral normalization constrains the discriminator’s Lipschitz constant, which stabilizes adversarial training. The attention mechanisms inside the discriminator are not just there for discrimination quality — the paper’s ablation confirms that removing Spatial Attention from the discriminator raises FID by more than 11 points, a larger degradation than removing it from the generator. A discriminator that notices spatial structure provides a better gradient signal back to the generator than one that treats the output as a flat bag of statistics.
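The Lipschitz-bounding effect of spectral normalization can be observed directly. A small sketch (an illustrative check, not the paper's discriminator) wraps a conv layer with `torch.nn.utils.spectral_norm` and inspects the largest singular value of the effective weight:

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)
# Wrap a conv layer the way each discriminator branch does.
layer = spectral_norm(nn.Conv2d(8, 8, kernel_size=1, bias=False))

# Each forward pass in training mode runs one power-iteration step,
# refining the estimate of the weight's top singular value sigma.
x = torch.randn(2, 8, 16, 16)
for _ in range(20):
    layer(x)

# The effective weight is W / sigma, so its largest singular value is
# driven toward 1, bounding the layer's Lipschitz constant.
w = layer.weight.detach().reshape(8, -1)
sigma_max = torch.linalg.svdvals(w)[0].item()
print(f"largest singular value after spectral norm: {sigma_max:.4f}")
```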

The Loss Function: Three Perspectives on Similarity

The generator loss combines an adversarial term (least-squares GAN) with a composite similarity loss that evaluates the quality of generated images from three independent angles.

Eq. 5–10 — Generator Loss $$L_G = L_\text{sim}(I_\text{gen}, I_\text{ref}) + \lambda \cdot L_\text{ls}\bigl(D(I_\text{gen},\ldots), 1\bigr)$$ $$L_\text{sim} = \alpha \cdot L_\text{VGG} + \beta \cdot L_\text{CS} + \gamma \cdot L_\text{MS\text{-}SSIM}$$ $$L_\text{VGG} = \frac{1}{W_{i,j}H_{i,j}}\sum_{x,y}\bigl(\phi_{i,j}(I_\text{ref})_{x,y} - \phi_{i,j}(I_\text{gen})_{x,y}\bigr)^2$$ $$L_\text{CS} = 1 - \frac{I_\text{gen} \cdot I_\text{ref}}{\|I_\text{gen}\|\,\|I_\text{ref}\|}$$ $$L_\text{MS\text{-}SSIM} = 1 - \text{MS-SSIM}(I_\text{gen}, I_\text{ref})$$

The VGG perceptual loss operates on deep feature maps from a pretrained network, evaluating perceptual similarity in a space where texture and structure are explicitly encoded. The cosine similarity loss penalizes directional misalignment in feature space — important for preserving spectral fidelity (the “color coherence”) of generated images across different land cover types. The MS-SSIM loss checks structural integrity across multiple spatial scales simultaneously, from individual pixel edges to kilometer-scale patterns.
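The directional behavior of the cosine term is worth seeing concretely. This is a minimal sketch of an $L_\text{CS}$-style loss (assuming per-image flattening, which the paper does not spell out), showing that it ignores uniform brightness changes but penalizes band misalignment:

```python
import torch
import torch.nn.functional as F

def cosine_loss(gen: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # L_CS = 1 - <gen, ref> / (||gen|| * ||ref||), computed per image
    return (1.0 - F.cosine_similarity(gen.flatten(1), ref.flatten(1), dim=1)).mean()

torch.manual_seed(0)
ref = torch.rand(2, 3, 32, 32)

# Identical images: loss ~0.
print(cosine_loss(ref, ref).item())

# A uniform brightness change preserves the feature direction, so the
# loss stays ~0 -- this term targets spectral (band-ratio) fidelity,
# not absolute intensity.
print(cosine_loss(ref * 1.5, ref).item())

# Swapping the color bands rotates the direction and is penalized.
shuffled = ref[:, [2, 0, 1]]
print(cosine_loss(shuffled, ref).item())
```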

The ablation study pins down what each one contributes: removing VGG loss causes the largest FID drop (to 88.46, versus 72.79 for the full model), confirming it drives perceptual realism. Removing cosine similarity causes the second-largest FID drop. MS-SSIM removal has the smallest individual effect but still degrades PSNR and SSIM — its multi-scale structural constraint turns out to be complementary rather than redundant with the other two.


TCSEN12: A Benchmark Built on Reality, Not Optimism

Most SAR-optical cloud removal datasets are constructed by artificially adding synthetic cloud masks to otherwise clean images, or by carefully curating pairs where the reference image happens to be cloud-free. TCSEN12 rejects both shortcuts. It was built around the Zhengzhou flood event of 20 July 2021 — a region in central China that experienced catastrophic rainfall and where the entire satellite imaging window was disrupted by persistent cloud cover before, during, and after the flood.

The dataset covers longitude 112°E–116°E and latitude 33°N–36.5°N, with Sentinel-1 SAR and Sentinel-2 optical imagery from July–August 2021. The two-date subset contains 4,424 image pairs; the three-date subset contains 1,681 samples. A critical design choice: the “reference” images in TCSEN12 are defined as having less than 5% cloud coverage — not zero percent, because achieving completely cloud-free images is essentially impossible over this region during this period. This is what the authors mean by a “realistic benchmark.”
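The sub-5% criterion is simple to express in code. The snippet below is a hypothetical illustration of such a filter (the helper names `cloud_fraction` and `is_valid_reference` are mine, not from the TCSEN12 construction pipeline):

```python
import numpy as np

def cloud_fraction(cloud_mask: np.ndarray) -> float:
    """Fraction of pixels flagged as cloud (mask: 1 = cloud, 0 = clear)."""
    return float(cloud_mask.mean())

def is_valid_reference(cloud_mask: np.ndarray, max_cloud: float = 0.05) -> bool:
    # TCSEN12-style criterion: a "reference" patch may carry up to 5% cloud,
    # since fully cloud-free acquisitions were essentially unobtainable over
    # the Zhengzhou region during July-August 2021.
    return cloud_fraction(cloud_mask) < max_cloud

rng = np.random.default_rng(0)
mask_clear = (rng.random((256, 256)) < 0.02).astype(np.uint8)   # ~2% cloud
mask_dirty = (rng.random((256, 256)) < 0.30).astype(np.uint8)   # ~30% cloud
print(is_valid_reference(mask_clear), is_valid_reference(mask_dirty))
```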

“The absence of satellite images can lead to inaccurate disaster predictions and poorly planned evacuations. This lack of timely information not only affects disaster response and recovery but also makes it challenging to fully evaluate vegetation growth and agricultural conditions.” — Duan, Belgiu & Stein, ISPRS Journal of Photogrammetry and Remote Sensing, 2026

Results: The Numbers That Matter

TCSEN12 Benchmark Comparison

| Method | PSNR ↑ | SSIM ↑ | MAE ↓ | RMSE ↓ | FID ↓ |
|---|---|---|---|---|---|
| BicycleGAN | 22.560 | 0.473 | 0.050 | 0.079 | 128.137 |
| MUNIT | 18.514 | 0.327 | 0.094 | 0.128 | 121.453 |
| Pix2pix | 20.317 | 0.389 | 0.073 | 0.100 | 142.605 |
| ResViT | 21.331 | 0.476 | 0.067 | 0.090 | 233.827 |
| MTS2ONet | 26.225 | 0.622 | 0.049 | 0.057 | 81.150 |
| HPN-CR | 26.225 | 0.634 | 0.049 | 0.058 | 87.822 |
| DGMR | 25.597 | 0.618 | 0.055 | 0.062 | 78.811 |
| TCRNet | 24.408 | 0.569 | 0.062 | 0.070 | 87.936 |
| CRGenNet (ours) | 26.978 | 0.648 | 0.041 | 0.050 | 72.789 |

DTSEN1-2 Cross-Dataset Validation

| Method | PSNR ↑ | SSIM ↑ | RMSE ↓ |
|---|---|---|---|
| MTS2ONet | 36.043 | 0.990 | 0.017 |
| CRGenNet (ours) | 88.926 | 0.9999 | 0.010 |

The DTSEN1-2 numbers look almost unbelievably good (PSNR of 88.93 versus MTS2ONet’s 36.04). This reflects that the DTSEN1-2 dataset is less challenging — temporally adjacent image pairs with stable surface conditions and SSIM of roughly 0.96 between adjacent observations. CRGenNet’s SAR-guided architecture essentially reconstructs near-identical scenes perfectly in this low-variability setting. The meaningful test is TCSEN12, where the gap is smaller but the conditions are real.
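A quick PSNR calculation shows why low-variability data inflates the metric so dramatically. This toy sketch (not the paper's evaluation code) computes PSNR for a moderate reconstruction error versus a near-identical pair:

```python
import math
import torch

def psnr(gen: torch.Tensor, ref: torch.Tensor, max_val: float = 1.0) -> float:
    # PSNR = 10 * log10(MAX^2 / MSE); every halving of RMSE adds ~6 dB.
    mse = torch.mean((gen - ref) ** 2).item()
    return 10.0 * math.log10(max_val ** 2 / mse)

torch.manual_seed(0)
ref = torch.rand(3, 64, 64)
noisy = ref + 0.05 * torch.randn_like(ref)   # moderate error -> ~20-30 dB
near = ref + 3e-5 * torch.randn_like(ref)    # near-identical -> far above 80 dB

print(round(psnr(noisy, ref), 1))
print(round(psnr(near, ref), 1))
```

When adjacent observations are already almost identical, residuals shrink toward zero and PSNR explodes, which is exactly the regime DTSEN1-2 sits in.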

Computational Efficiency

| Method | Parameters (M) | GFLOPs | FPS |
|---|---|---|---|
| ResViT | 123.45 | 120.62 | 90.40 |
| MTS2ONet | 33.27 | 134.91 | 69.37 |
| HPN-CR | 3.55 | 33.13 | 43.49 |
| DGMR | 1.73 | 226.02 | 21.70 |
| CRGenNet (ours) | 41.25 | 122.28 | 89.09 |

Complete End-to-End CRGenNet Implementation (PyTorch)

The implementation covers all components from the paper, organized into 12 sections: imports and configuration, convolutional building blocks (MiniConvMish, ConvMishBlock, ExtendedConvBlock), channel and spatial attention modules, the DownUpBlock cross-modal encoder-decoder, the FusionAttention mechanism, Swin-based feature extraction, the multi-scale decoder with channel and spatial attention, the full Generator, the multi-scale Discriminator with spectral normalization, the composite loss function (VGG + Cosine + MS-SSIM + WGAN-GP), the complete CRGenNet wrapper with a training loop, and a synthetic smoke test.

# ==============================================================================
# CRGenNet: Cloud-Resilient Generation Network for Optical Image Synthesis
# Paper: ISPRS Journal of Photogrammetry and Remote Sensing 236 (2026) 255-272
# Authors: Chenxi Duan, Mariana Belgiu, Alfred Stein
# Affiliation: University of Twente, ITC Faculty, The Netherlands
# DOI: https://doi.org/10.1016/j.isprsjprs.2026.03.042
# ==============================================================================
# Sections:
#   1.  Imports & Configuration
#   2.  Convolutional Blocks (MiniConvMish, ConvMishBlock, ExtendedConvBlock)
#   3.  Channel & Spatial Attention Modules
#   4.  DownUpBlock (cross-modal SAR-optical encoder-decoder, Section 2.2.1)
#   5.  FusionAttention (cross-temporal joint attention, Eq. 1-4)
#   6.  Swin-based Feature Extraction (simplified scaled-up SwinBlocks)
#   7.  Multi-scale Decoder with Channel & Spatial Attention
#   8.  Full Generator
#   9.  Multi-scale Discriminator with Spectral Normalization
#  10.  Loss Functions (VGG perceptual + Cosine + MS-SSIM + WGAN-GP, Eq. 5-11)
#  11.  CRGenNet Wrapper + Training Loop
#  12.  Smoke Test
# ==============================================================================

from __future__ import annotations
import math, warnings
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.utils import spectral_norm
warnings.filterwarnings("ignore")


# ─── SECTION 1: Configuration ─────────────────────────────────────────────────

class CRGenNetConfig:
    """
    CRGenNet hyperparameters matching the paper's TCSEN12 training setup.
    Adjust channels/depth to balance accuracy vs. compute budget.
    """
    # Image dimensions (Sentinel-2 patches, typical crop size)
    img_size: int = 256
    in_channels_s1: int = 2         # Sentinel-1: VV + VH polarizations
    in_channels_s2: int = 3         # Sentinel-2: RGB bands (or more for multispectral)
    out_channels: int = 3           # Generated cloud-free optical image

    # Network width
    base_ch: int = 64               # Base channel count (paper uses 64)
    swin_embed_dim: int = 96        # Swin transformer embedding dimension
    num_swin_blocks: int = 4        # 4 scaled-up SwinBlocks as in paper
    window_size: int = 8            # Swin window size

    # Training
    lr_g: float = 0.001             # Generator learning rate
    lr_d: float = 0.001             # Discriminator learning rate
    beta1: float = 0.9
    beta2: float = 0.999
    lambda_adv: float = 1.0        # λ weight for adversarial loss
    lambda_gp: float = 10.0        # Gradient penalty weight
    alpha_vgg: float = 1.0         # VGG perceptual loss weight
    beta_cs: float = 1.0           # Cosine similarity loss weight
    gamma_ssim: float = 1.0        # MS-SSIM loss weight
    num_epochs: int = 200
    batch_size: int = 8

    def __init__(self, **kwargs):
        for k, v in kwargs.items(): setattr(self, k, v)


# ─── SECTION 2: Convolutional Building Blocks ─────────────────────────────────

class Mish(nn.Module):
    """
    Mish activation: f(x) = x * tanh(softplus(x)).
    Smooth, non-monotonic. Better gradient flow than ReLU in deep stacks,
    especially across extreme contrast ratios (cloud vs. clear sky pixels).
    """
    def forward(self, x: Tensor) -> Tensor:
        return x * torch.tanh(F.softplus(x))


class MiniConvMish(nn.Module):
    """
    Lightest building block: ReplicationPad → Conv3×3 → Mish.
    No normalization — used for initial feature extraction where
    instance statistics may be misleading (e.g., radar backscatter).
    """
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, bias=False),
            Mish(),
        )
    def forward(self, x: Tensor) -> Tensor: return self.block(x)


class ConvMishBlock(nn.Module):
    """
    Standard residual block: ReplicationPad → Conv3×3 → InstanceNorm → Mish
    (repeated twice). Used throughout Generator and DownUpBlock.
    InstanceNorm is chosen over BatchNorm for single-image inference stability.
    """
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1, use_deconv: bool = False):
        super().__init__()
        if use_deconv:
            # Transposed conv handles its own padding; with output_padding =
            # stride-1 a stride-s layer upsamples H exactly to s*H.
            self.block = nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=stride,
                                   padding=1, output_padding=stride - 1, bias=False),
                nn.InstanceNorm2d(out_ch),
                Mish(),
                nn.ReplicationPad2d(1),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, bias=False),
                nn.InstanceNorm2d(out_ch),
                Mish(),
            )
            # Skip path must match the upsampled output size for the residual add.
            self.skip = (nn.Identity() if (in_ch == out_ch and stride == 1)
                         else nn.Sequential(
                             nn.Upsample(scale_factor=stride, mode="nearest"),
                             nn.Conv2d(in_ch, out_ch, kernel_size=1)))
        else:
            self.block = nn.Sequential(
                nn.ReplicationPad2d(1),
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, bias=False),
                nn.InstanceNorm2d(out_ch),
                Mish(),
                nn.ReplicationPad2d(1),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, bias=False),
                nn.InstanceNorm2d(out_ch),
                Mish(),
            )
            # Strided 1x1 skip mirrors the main branch's downsampling so the
            # residual addition is shape-safe for stride > 1.
            self.skip = (nn.Identity() if (in_ch == out_ch and stride == 1)
                         else nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride))

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x) + self.skip(x)


class ExtendedConvBlock(nn.Module):
    """
    MiniConvMish + 1×1 Conv + Tanh. Used as the interface layer between
    FusionAttention and SwinBlocks, mapping features to a bounded range.
    """
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, bias=False),
            Mish(),
            nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False),
            nn.Tanh(),
        )
    def forward(self, x: Tensor) -> Tensor: return self.block(x)


# ─── SECTION 3: Channel & Spatial Attention ────────────────────────────────────

class ChannelAttention(nn.Module):
    """
    Squeeze-and-Excitation style channel attention.
    Applied to low-resolution, high-channel-count features in the Decoder.
    Forces the network to emphasize spectral channels critical for
    reconstructing land cover (vegetation NDVI, water reflectance, etc.).
    """
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        mid = max(1, channels // reduction)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, mid, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid, channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return x * self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """
    Convolutional spatial attention.
    Applied to high-resolution, low-channel features.
    Helps the decoder focus on cloud-contaminated spatial regions —
    ablation shows this is the single most impactful attention component
    (removal raises FID by 11.5 points more than channel attention removal).
    """
    def __init__(self, kernel_size: int = 7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        avg_out = x.mean(dim=1, keepdim=True)
        max_out, _ = x.max(dim=1, keepdim=True)
        attn = self.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))
        return x * attn


# ─── SECTION 4: DownUpBlock ────────────────────────────────────────────────────

class DownUpBlock(nn.Module):
    """
    DownUpBlock: Cross-modal SAR-optical feature extraction (Section 2.2.1).

    Purpose: Extract a joint representation from co-acquired SAR and optical
    data at t1 BEFORE any temporal comparison. This gives CRGenNet an
    all-weather structural baseline that is immune to cloud contamination.

    Architecture:
      Down×2 (Conv, stride=2): spatial compression, noise suppression
      Up×2   (DeConv, stride=2): spatial restoration
      ConvMishBlock×2: further refinement
      Conv1×1: channel bottleneck
      Dropout: regularization, prevents cloud-location memorization
      Residual: adds input back (starts as identity, learns refinements)

    The hierarchical downsampling suppresses SAR speckle noise by forcing
    the encoder to retain only the most stable spatial structures.
    """
    def __init__(self, in_ch: int, out_ch: int, dropout_p: float = 0.2):
        super().__init__()
        mid = out_ch

        # Two down-sampling steps (Conv, stride=2)
        self.down1 = ConvMishBlock(in_ch, mid, stride=2)
        self.down2 = ConvMishBlock(mid, mid, stride=2)

        # Two up-sampling steps (DeConv / transposed conv, stride=2)
        self.up1 = ConvMishBlock(mid, mid, stride=2, use_deconv=True)
        self.up2 = ConvMishBlock(mid, out_ch, stride=2, use_deconv=True)

        # Post-upsample refinement (per Fig. 4)
        self.refine = nn.Sequential(
            ConvMishBlock(out_ch, out_ch),
            ConvMishBlock(out_ch, out_ch),
            nn.Conv2d(out_ch, out_ch, kernel_size=1),
            nn.Dropout2d(p=dropout_p),
        )

        # Residual projection if channel counts differ
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x: Tensor) -> Tensor:
        """
        x: (B, in_ch, H, W) — concatenated SAR and (cloud-affected) optical
        Returns: (B, out_ch, H, W) — joint SAR-guided feature map
        """
        d1 = self.down1(x)      # (B, mid, H/2, W/2)
        d2 = self.down2(d1)     # (B, mid, H/4, W/4)
        u1 = self.up1(d2)       # (B, mid, H/2, W/2)
        u2 = self.up2(u1)       # (B, out_ch, H, W)
        out = self.refine(u2)
        return out + self.skip(x)   # residual connection


# ─── SECTION 5: FusionAttention ───────────────────────────────────────────────

class FusionAttention(nn.Module):
    """
    FusionAttention: Cross-temporal joint attention mechanism (Section 2.2.1, Eq. 1-4).

    Intelligently aligns features from two dates without relying on
    explicit change detection (which breaks under cloud contamination).

    Key design choices:
    - SHARED attention weights: Q,K are concatenated from both inputs and
      L2-normalized → single weight map that sees both perspectives at once
    - SEPARATE value projections: V1,V2 allow each date's content to be
      expressed differently while sharing the same spatial routing
    - γ initialized to 0: block starts as identity, learns to route only
      where evidence supports temporal fusion
    """
    def __init__(self, in_ch: int):
        super().__init__()
        # Query, key, value projections for each input stream
        self.q1 = nn.Conv2d(in_ch, in_ch, kernel_size=1)
        self.q2 = nn.Conv2d(in_ch, in_ch, kernel_size=1)
        self.k1 = nn.Conv2d(in_ch, in_ch, kernel_size=1)
        self.k2 = nn.Conv2d(in_ch, in_ch, kernel_size=1)
        self.v1 = nn.Conv2d(in_ch, in_ch, kernel_size=1)
        self.v2 = nn.Conv2d(in_ch, in_ch, kernel_size=1)

        # After attention, fuse back to in_ch (was 2*in_ch concatenated Q,K)
        self.fuse = nn.Conv2d(in_ch * 2, in_ch, kernel_size=1)

        # γ: learnable scalar, initialized to 0 (identity start, Eq. 4)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.scale = None   # set dynamically based on d_k

    def forward(self, x1: Tensor, x2: Tensor) -> Tuple[Tensor, Tensor]:
        """
        x1: (B, C, H, W) — t1 fused SAR+optical features (from DownUpBlock)
        x2: (B, C, H, W) — t2 SAR features (structural reference for target date)

        Returns: (y1, y2) each (B, C, H, W)
          y1 = x1 enriched with cross-temporal attention from x2
          y2 = x2 enriched with cross-temporal attention from x1
        """
        B, C, H, W = x1.shape
        d_k = C * 2   # dimension after concat (for scaling)

        # Project to Q, K, V for each input (Eq. 1)
        Q1 = self.q1(x1).reshape(B, C, -1)   # (B, C, H*W)
        Q2 = self.q2(x2).reshape(B, C, -1)
        K1 = self.k1(x1).reshape(B, C, -1)
        K2 = self.k2(x2).reshape(B, C, -1)
        V1 = self.v1(x1).reshape(B, C, -1)   # (B, C, H*W)
        V2 = self.v2(x2).reshape(B, C, -1)

        # Joint Q, K with L2 normalization (Eq. 1)
        Q = torch.cat([Q1, Q2], dim=1)  # (B, 2C, H*W)
        K = torch.cat([K1, K2], dim=1)
        Q = F.normalize(Q, p=2, dim=1)
        K = F.normalize(K, p=2, dim=1)

        # Scaled dot-product attention (Eq. 2)
        # Q: (B, 2C, HW), K: (B, 2C, HW) → weights: (B, HW, HW)
        attn = torch.bmm(Q.permute(0,2,1), K) / math.sqrt(d_k)
        weights = F.softmax(attn, dim=-1)   # (B, HW, HW)

        # Apply shared weights to each value stream (Eq. 3)
        # V: (B, C, HW) → after bmm: (B, C, HW) → reshape to (B, C, H, W)
        attn1 = torch.bmm(V1, weights.permute(0,2,1)).reshape(B, C, H, W)
        attn2 = torch.bmm(V2, weights.permute(0,2,1)).reshape(B, C, H, W)

        # Residual update with learnable γ (Eq. 4)
        y1 = x1 + self.gamma * attn1
        y2 = x2 + self.gamma * attn2
        return y1, y2
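The attention math in `FusionAttention` (Eqs. 1-4) can be exercised standalone. The sketch below substitutes identity projections for the learned 1×1 convolutions (an illustrative simplification, not the class's actual parameters) and checks that γ = 0 makes the block an identity mapping at initialization:

```python
import math
import torch
import torch.nn.functional as F

B, C, H, W = 1, 4, 8, 8
x1 = torch.randn(B, C, H, W)          # t1 fused SAR+optical features
x2 = torch.randn(B, C, H, W)          # t2 SAR features

# Joint query/key over both temporal streams (identity projections here)
Q = torch.cat([x1, x2], dim=1).reshape(B, 2 * C, -1)   # (B, 2C, H*W)
K = Q.clone()
Q, K = F.normalize(Q, p=2, dim=1), F.normalize(K, p=2, dim=1)

# Shared attention weights over spatial positions (Eq. 2)
weights = F.softmax(Q.permute(0, 2, 1) @ K / math.sqrt(2 * C), dim=-1)  # (B, HW, HW)

V1 = x1.reshape(B, C, -1)
attn1 = (V1 @ weights.permute(0, 2, 1)).reshape(B, C, H, W)

gamma = torch.zeros(1)                # γ = 0 at init (Eq. 4)
y1 = x1 + gamma * attn1
assert torch.allclose(y1, x1)         # identity start: output equals input
```

Because γ starts at zero, training begins with the cross-temporal branch disabled and the network learns how much attention signal to inject.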


# ─── SECTION 6: Swin-based Feature Extraction ────────────────────────────────

class WindowAttention(nn.Module):
    """
    Simplified Swin-style attention block. The full Swin implementation
    partitions the feature map into shifted windows and adds a relative
    position bias; for clarity this version applies standard global
    attention over all spatial positions (window_size is kept for API
    compatibility but unused). For production use, install timm and use
    SwinTransformerBlock directly.
    """
    def __init__(self, dim: int, window_size: int = 8, num_heads: int = 4):
        super().__init__()
        self.ws = window_size   # unused in this simplified variant
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
        )

    def forward(self, x: Tensor) -> Tensor:
        """x: (B, C, H, W)"""
        B, C, H, W = x.shape
        # Flatten spatial dims into a token sequence, apply pre-norm
        # attention and FFN with residuals, then restore (B, C, H, W)
        x_flat = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x_norm = self.norm1(x_flat)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x_flat = x_flat + attn_out
        x_flat = x_flat + self.ffn(self.norm2(x_flat))
        return x_flat.reshape(B, H, W, C).permute(0, 3, 1, 2)
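The simplified block above attends globally; real Swin attention first partitions the map into non-overlapping `window_size × window_size` windows. A sketch of that fold/unfold step (the helper names here are ours, not from the paper), with the round trip shown to be lossless:

```python
import torch

def window_partition(x: torch.Tensor, ws: int) -> torch.Tensor:
    """(B, C, H, W) -> (B * nWindows, ws*ws, C) for per-window attention."""
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // ws, ws, W // ws, ws)
    return x.permute(0, 2, 4, 3, 5, 1).reshape(-1, ws * ws, C)

def window_reverse(win: torch.Tensor, ws: int,
                   B: int, C: int, H: int, W: int) -> torch.Tensor:
    """Inverse of window_partition."""
    x = win.reshape(B, H // ws, W // ws, ws, ws, C)
    return x.permute(0, 5, 1, 3, 2, 4).reshape(B, C, H, W)

x = torch.randn(2, 16, 32, 32)
w = window_partition(x, ws=8)            # 4x4 = 16 windows per image
assert w.shape == (2 * 16, 64, 16)       # (windows, tokens, channels)
assert torch.equal(window_reverse(w, 8, 2, 16, 32, 32), x)   # lossless
```

Attention cost then scales with window area instead of the full H·W, which is what makes Swin practical on large satellite tiles.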


class ScaledSwinBlock(nn.Module):
    """
    Single scale of the 4-scale Swin feature extraction tower.
    Patch partition → Window attention → optional downsampling.
    Paper uses "scaled-up Swin Transformer architecture" with
    increased capacity and scaled cosine attention.
    """
    def __init__(self, in_ch: int, out_ch: int, downsample: bool = True):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.attn = WindowAttention(out_ch)
        self.down = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(out_ch),
            Mish(),
        ) if downsample else nn.Identity()

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        x = self.proj(x)
        x = self.attn(x)
        skip = x                # skip connection for decoder
        x = self.down(x)
        return x, skip          # (downsampled_feat, skip_feat)


class SwinTower(nn.Module):
    """
    4-scale SwinBlock tower producing hierarchical skip connections
    for the multi-scale decoder. Channels double at each scale.
    """
    def __init__(self, in_ch: int, base_ch: int, num_scales: int = 4):
        super().__init__()
        channels = [base_ch * (2**i) for i in range(num_scales)]
        in_chs = [in_ch] + channels[:-1]
        self.scales = nn.ModuleList([
            ScaledSwinBlock(in_chs[i], channels[i], downsample=(i < num_scales-1))
            for i in range(num_scales)
        ])
        self.out_channels = channels

    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
        skips = []
        for scale in self.scales:
            x, skip = scale(x)
            skips.append(skip)
        return x, skips           # (bottleneck, [skip_1, ..., skip_4])


# ─── SECTION 7: Multi-scale Decoder ───────────────────────────────────────────

class DecoderBlock(nn.Module):
    """
    Single decoder stage: upsample → concat skip → channel/spatial attention.
    Channel Attention for low-res / high-channel features (early decoder).
    Spatial Attention for high-res / low-channel features (late decoder).
    """
    def __init__(self, in_ch: int, skip_ch: int, out_ch: int, use_channel_att: bool = True):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            MiniConvMish(in_ch, out_ch),
        )
        self.conv = MiniConvMish(out_ch + skip_ch, out_ch)
        self.att = ChannelAttention(out_ch) if use_channel_att else SpatialAttention()

    def forward(self, x: Tensor, skip: Tensor) -> Tensor:
        x = self.up(x)
        # Handle spatial size mismatch (e.g., odd input sizes)
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        x = self.conv(torch.cat([x, skip], dim=1))
        x = self.att(x)
        return x


class MultiScaleDecoder(nn.Module):
    """
    Hierarchical decoder matching the SwinTower's 4-scale skip connections.
    The deepest skip equals the bottleneck (the last Swin scale does not
    downsample), so decoding runs over the three shallower skips:
    early stages use Channel Attention (low-res, many channels),
    the final stage uses Spatial Attention (full-res, few channels).
    """
    def __init__(self, swin_channels: List[int], out_ch: int):
        super().__init__()
        chs = list(reversed(swin_channels))   # [C_deep, ..., C_shallow]
        self.blocks = nn.ModuleList()
        for i in range(len(chs) - 1):
            use_ca = (i < len(chs) - 2)       # Channel Att for all but last stage
            self.blocks.append(
                DecoderBlock(chs[i], chs[i + 1], chs[i + 1], use_channel_att=use_ca)
            )
        final_ch = chs[-1]
        self.head = nn.Sequential(
            nn.Dropout2d(p=0.1),
            MiniConvMish(final_ch, final_ch // 2),
            MiniConvMish(final_ch // 2, out_ch),
            nn.Tanh(),              # output range [-1, 1], matches Sentinel-2 normalized reflectance
        )

    def forward(self, bottleneck: Tensor, skips: List[Tensor]) -> Tensor:
        x = bottleneck
        skips_rev = list(reversed(skips))     # deepest first; skips_rev[0] == bottleneck
        for i, block in enumerate(self.blocks):
            x = block(x, skips_rev[i + 1])    # pair each stage with the next-shallower skip
        return self.head(x)                   # resolution now matches the network input


# ─── SECTION 8: Full Generator ────────────────────────────────────────────────

class CRGenNetGenerator(nn.Module):
    """
    CRGenNet Generator: produces cloud-free optical image at t2 (Section 2.2.1).

    Input streams:
      S1_t1 (Sentinel-1 SAR at t1) + S2_t1 (contaminated Sentinel-2 at t1)
      S1_t2 (Sentinel-1 SAR at t2) — provides target-date structural reference

    Data flow:
      1. MiniConvMish on S1_t1 → feature_sar_t1
      2. ConvMishBlock on S1_t2 → feature_sar_t2
      3. Concat(feature_sar_t1, S2_t1) → ConvMishBlock → fused_t1
      4. DownUpBlock(fused_t1) → downup_feat (SAR-corrected, noise-suppressed)
      5. FusionAttention(downup_feat, feature_sar_t2) → (y1, y2)
      6. ExtendedConvBlock(concat(y1, y2)) → ext_feat
      7. SwinTower(ext_feat) → bottleneck + 4 skip connections
      8. MultiScaleDecoder(bottleneck, skips) → cloud-free S2_t2
    """

    def __init__(self, cfg: CRGenNetConfig):
        super().__init__()
        B = cfg.base_ch
        s1_ch = cfg.in_channels_s1
        s2_ch = cfg.in_channels_s2

        # Step 1: Initial feature extraction
        self.feat_s1_t1 = MiniConvMish(s1_ch, B)
        self.feat_s1_t2 = ConvMishBlock(s1_ch, B)

        # Step 2: Early fusion (SAR_t1 guides optical_t1)
        self.early_fuse = ConvMishBlock(B + s2_ch, B)

        # Step 3: DownUpBlock — cross-modal SAR-optical encoder-decoder
        self.downup = DownUpBlock(B, B)

        # Step 4: FusionAttention — cross-temporal alignment
        self.fusion_att = FusionAttention(B)

        # Step 5: Interface to SwinTower
        self.ext_conv = ExtendedConvBlock(B * 2, B)

        # Step 6: Multi-scale feature extraction
        self.swin_tower = SwinTower(B, B, num_scales=4)

        # Step 7: Hierarchical decoder
        self.decoder = MultiScaleDecoder(self.swin_tower.out_channels, cfg.out_channels)

    def forward(self, s1_t1: Tensor, s2_t1: Tensor, s1_t2: Tensor) -> Tensor:
        """
        s1_t1: (B, 2, H, W)  Sentinel-1 at t1
        s2_t1: (B, 3, H, W)  Sentinel-2 at t1 (may be cloud-contaminated)
        s1_t2: (B, 2, H, W)  Sentinel-1 at t2

        Returns: I_gen (B, 3, H, W) — cloud-free Sentinel-2 at t2
        """
        # Feature extraction
        f_sar_t1 = self.feat_s1_t1(s1_t1)     # (B, B, H, W)
        f_sar_t2 = self.feat_s1_t2(s1_t2)     # (B, B, H, W)

        # SAR-guided early fusion of t1 modalities
        fused_t1 = self.early_fuse(torch.cat([f_sar_t1, s2_t1], dim=1))

        # DownUpBlock: hierarchical noise suppression and cross-modal integration
        downup_feat = self.downup(fused_t1)    # (B, B, H, W)

        # FusionAttention: cross-temporal alignment (no change detection!)
        y1, y2 = self.fusion_att(downup_feat, f_sar_t2)

        # Merge and project to SwinTower input
        ext_feat = self.ext_conv(torch.cat([y1, y2], dim=1))

        # Hierarchical feature extraction
        bottleneck, skips = self.swin_tower(ext_feat)

        # Hierarchical decoding with skip connections and attention
        output = self.decoder(bottleneck, skips)
        return output


# ─── SECTION 9: Multi-scale Discriminator ─────────────────────────────────────

class SubDiscriminator(nn.Module):
    """
    Single branch of the multi-scale discriminator (one of D1, D2, D3).
    Block pattern: spectral-norm conv → BatchNorm → Mish; the head adds
    channel and spatial attention before a sigmoid decision map.
    Spectral normalization enforces a Lipschitz constraint for GAN stability.
    Attention helps detect subtle differences in cloud-removal quality.
    """
    def __init__(self, in_ch: int, base_ch: int = 64):
        super().__init__()
        def dis_block(ic, oc, stride=2):
            return nn.Sequential(
                spectral_norm(nn.Conv2d(ic, oc, 4, stride=stride, padding=1)),
                nn.BatchNorm2d(oc),
                Mish(),
            )
        self.net = nn.Sequential(
            dis_block(in_ch, base_ch, stride=2),
            dis_block(base_ch, base_ch*2, stride=2),
            dis_block(base_ch*2, base_ch*4, stride=2),
            dis_block(base_ch*4, base_ch*8, stride=1),
        )
        self.channel_att = ChannelAttention(base_ch*8)
        self.spatial_att = SpatialAttention()
        # k=3 (not 4) so the quarter-scale branch's 1x1 feature map stays valid
        self.out = spectral_norm(nn.Conv2d(base_ch*8, 1, kernel_size=3, padding=1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        feat = self.net(x)
        feat = self.channel_att(feat)
        feat = self.spatial_att(feat)
        return self.sigmoid(self.out(feat))
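A quick shape-arithmetic check (assuming 64×64 training tiles, an illustrative choice) of how the stride-2 blocks shrink the quarter-resolution branch, and why the head's kernel size is tight there:

```python
def conv_out(n: int, k: int, s: int, p: int) -> int:
    """Spatial output size of a conv: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * p - k) // s + 1

n = 16                                   # quarter scale of a 64x64 tile
for _ in range(3):                       # three stride-2 blocks
    n = conv_out(n, k=4, s=2, p=1)       # 16 -> 8 -> 4 -> 2
n = conv_out(n, k=4, s=1, p=1)           # stride-1 block: 2 -> 1

assert n == 1                            # feature map is down to 1x1
assert conv_out(1, k=3, s=1, p=1) == 1   # a k=3 head keeps a valid output
assert conv_out(1, k=4, s=1, p=1) == 0   # a k=4 head would collapse to zero
```

The full- and half-resolution branches have room to spare; the quarter-scale branch is the one that dictates the head's kernel size.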


class MultiScaleDiscriminator(nn.Module):
    """
    3-branch discriminator operating at full, half, and quarter resolution.
    Combined decision: α·D1 + β·D2 + γ·D3 (learnable weights, initialized equal).
    Multi-scale design captures both global structure and local texture quality.

    Inputs to discriminator: (I_gen/I_ref, S1_t1, S1_t2, S2_t1) concatenated,
    so the discriminator sees the full context, not just the generated image.
    """
    def __init__(self, in_channels: int, base_ch: int = 64):
        super().__init__()
        self.d1 = SubDiscriminator(in_channels, base_ch)
        self.d2 = SubDiscriminator(in_channels, base_ch)
        self.d3 = SubDiscriminator(in_channels, base_ch)
        self.alpha = nn.Parameter(torch.ones(1) / 3)
        self.beta  = nn.Parameter(torch.ones(1) / 3)
        self.gamma_w = nn.Parameter(torch.ones(1) / 3)

    def forward(self, x: Tensor) -> Tensor:
        """
        x: (B, C_total, H, W) — generated or real image concatenated with conditioning
        Returns scalar decision per spatial location (averaged over scales).
        """
        d1_out = self.d1(x)
        d2_out = self.d2(F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False))
        d3_out = self.d3(F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=False))

        # Resize d2, d3 to d1 spatial resolution before weighted sum
        d2_up = F.interpolate(d2_out, size=d1_out.shape[-2:], mode='bilinear', align_corners=False)
        d3_up = F.interpolate(d3_out, size=d1_out.shape[-2:], mode='bilinear', align_corners=False)
        return self.alpha * d1_out + self.beta * d2_up + self.gamma_w * d3_up


# ─── SECTION 10: Loss Functions ───────────────────────────────────────────────

class VGGPerceptualLoss(nn.Module):
    """
    VGG perceptual loss (Eq. 7). Uses features from VGG16 relu3_3 layer.
    Evaluates perceptual similarity rather than pixel-level difference,
    critical for preserving texture realism in satellite imagery generation.
    """
    def __init__(self):
        super().__init__()
        try:
            import torchvision.models as models
            vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
            self.features = nn.Sequential(*list(vgg.features)[:16]).eval()   # conv1_1 ... relu3_3
            for p in self.features.parameters():
                p.requires_grad = False
            self.ready = True
        except Exception:
            self.ready = False
            print("[VGGLoss] torchvision not available, falling back to L1 loss.")

    def forward(self, gen: Tensor, ref: Tensor) -> Tensor:
        if not self.ready:
            return F.l1_loss(gen, ref)
        # Normalize to ImageNet mean/std for VGG
        mean = torch.tensor([0.485, 0.456, 0.406], device=gen.device).view(1,3,1,1)
        std  = torch.tensor([0.229, 0.224, 0.225], device=gen.device).view(1,3,1,1)
        gen_n = (gen * 0.5 + 0.5 - mean) / std   # [-1,1] → [0,1] → normalized
        ref_n = (ref * 0.5 + 0.5 - mean) / std
        # Match VGG input channels (if S2 uses more than 3 bands, take first 3)
        gen_3 = gen_n[:, :3]
        ref_3 = ref_n[:, :3]
        f_gen = self.features(gen_3)
        f_ref = self.features(ref_3)
        return F.mse_loss(f_gen, f_ref)
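The two-step range mapping in `forward` can be checked on the extremes: the generator's Tanh output in [-1, 1] first shifts to [0, 1], then gets ImageNet-normalized for VGG:

```python
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std  = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

# [-1, 1] -> [0, 1] -> ImageNet-normalized, evaluated at both extremes
lo = (torch.full((1, 3, 1, 1), -1.0) * 0.5 + 0.5 - mean) / std
hi = (torch.full((1, 3, 1, 1),  1.0) * 0.5 + 0.5 - mean) / std

assert torch.allclose(lo, -mean / std)        # -1 lands at (0 - mean)/std
assert torch.allclose(hi, (1 - mean) / std)   # +1 lands at (1 - mean)/std
```

Skipping this remap (or feeding raw reflectance) would push inputs far outside the statistics the pretrained VGG expects, making the perceptual distance meaningless.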


def cosine_similarity_loss(gen: Tensor, ref: Tensor) -> Tensor:
    """
    Cosine Similarity Loss (Eq. 8): L_CS = 1 - cos(I_gen, I_ref).
    Measures directional alignment of feature vectors in pixel space.
    Enforces global color coherence and spectral fidelity.
    """
    gen_flat = gen.reshape(gen.shape[0], -1)
    ref_flat = ref.reshape(ref.shape[0], -1)
    cos_sim = F.cosine_similarity(gen_flat, ref_flat, dim=1)
    return (1.0 - cos_sim).mean()
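Because the cosine term compares direction only, it is invariant to a global brightness rescaling of the generated image; a standalone re-implementation makes that easy to verify:

```python
import torch
import torch.nn.functional as F

def cos_loss(gen: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """1 - cosine similarity over flattened per-sample pixel vectors."""
    g = gen.reshape(gen.shape[0], -1)
    r = ref.reshape(ref.shape[0], -1)
    return (1.0 - F.cosine_similarity(g, r, dim=1)).mean()

ref = torch.randn(2, 3, 8, 8)
assert cos_loss(ref, ref).item() < 1e-5         # identical images: zero loss
assert cos_loss(2.0 * ref, ref).item() < 1e-5   # global rescaling: still zero
assert cos_loss(-ref, ref).item() > 1.9         # opposite direction: near 2
```

That scale invariance is why L_CS is paired with the VGG and MS-SSIM terms rather than used alone: it enforces spectral direction while the other losses pin down magnitude and structure.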


def ms_ssim_loss(gen: Tensor, ref: Tensor, num_scales: int = 3) -> Tensor:
    """
    MS-SSIM Loss (Eq. 9): 1 - MS-SSIM(I_gen, I_ref).
    Evaluates structural similarity across spatial scales.
    Multi-scale formulation captures differences from pixel edges to
    kilometer-scale patterns in satellite imagery.
    """
    total_ssim = torch.tensor(0.0, device=gen.device)
    g, r = gen, ref
    for scale in range(num_scales):
        # Per-scale SSIM (simplified: using mean and variance)
        mu_g = F.avg_pool2d(g, kernel_size=11, stride=1, padding=5)
        mu_r = F.avg_pool2d(r, kernel_size=11, stride=1, padding=5)
        sigma_g = F.avg_pool2d(g**2, kernel_size=11, stride=1, padding=5) - mu_g**2
        sigma_r = F.avg_pool2d(r**2, kernel_size=11, stride=1, padding=5) - mu_r**2
        sigma_gr = F.avg_pool2d(g * r, kernel_size=11, stride=1, padding=5) - mu_g * mu_r
        C1, C2 = 0.01**2, 0.03**2
        ssim = ((2 * mu_g * mu_r + C1) * (2 * sigma_gr + C2)) / \
               ((mu_g**2 + mu_r**2 + C1) * (sigma_g + sigma_r + C2))
        total_ssim = total_ssim + ssim.mean()
        # Downsample for next scale
        if scale < num_scales - 1:
            g = F.avg_pool2d(g, kernel_size=2)
            r = F.avg_pool2d(r, kernel_size=2)
    return 1.0 - total_ssim / num_scales
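A sanity check of the single-scale SSIM core used above (re-implemented standalone): identical images score 1 (zero loss), and additive noise pushes the score down:

```python
import torch
import torch.nn.functional as F

def ssim(g: torch.Tensor, r: torch.Tensor,
         C1: float = 0.01 ** 2, C2: float = 0.03 ** 2) -> torch.Tensor:
    """Simplified SSIM with 11x11 mean filters (same form as the loss above)."""
    mu_g = F.avg_pool2d(g, 11, 1, 5)
    mu_r = F.avg_pool2d(r, 11, 1, 5)
    s_g  = F.avg_pool2d(g * g, 11, 1, 5) - mu_g ** 2
    s_r  = F.avg_pool2d(r * r, 11, 1, 5) - mu_r ** 2
    s_gr = F.avg_pool2d(g * r, 11, 1, 5) - mu_g * mu_r
    return (((2 * mu_g * mu_r + C1) * (2 * s_gr + C2)) /
            ((mu_g ** 2 + mu_r ** 2 + C1) * (s_g + s_r + C2))).mean()

img = torch.rand(1, 3, 32, 32)
assert abs(ssim(img, img).item() - 1.0) < 1e-4            # self-similarity is 1
assert ssim(img, img + 0.3 * torch.randn_like(img)).item() < 0.99
```

The multi-scale version in the loss simply averages this score over successively average-pooled copies of both images.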


class CRGenNetLoss(nn.Module):
    """
    Complete CRGenNet loss functions (Section 2.2.3, Eq. 5-11).

    Generator Loss:
      L_G = L_sim(I_gen, I_ref) + λ·L_ls(D(I_gen,...), 1)
      L_sim = α·L_VGG + β·L_CS + γ·L_MS-SSIM   (L_CS and L_MS-SSIM are
      already dissimilarity terms of the form 1 - similarity, Eqs. 8-9)

    Discriminator Loss (WGAN with gradient penalty):
      L_D = L_real + L_fake + λ_gp · L_gp

    Ablation insight: L_VGG contributes most to perceptual quality (FID
    degrades by 11+ points when it is removed). L_CS maintains spectral
    coherence. L_MS-SSIM preserves structural fidelity at multiple spatial scales.
    """
    def __init__(self, cfg: CRGenNetConfig):
        super().__init__()
        self.vgg_loss = VGGPerceptualLoss()
        self.alpha_vgg = cfg.alpha_vgg
        self.beta_cs   = cfg.beta_cs
        self.gamma_s   = cfg.gamma_ssim
        self.lambda_adv = cfg.lambda_adv
        self.lambda_gp  = cfg.lambda_gp

    def similarity_loss(self, gen: Tensor, ref: Tensor) -> Tensor:
        """L_sim = α·L_VGG + β·(1-L_CS) + γ·(1-L_MS-SSIM). Eq. 6"""
        l_vgg  = self.vgg_loss(gen, ref)
        l_cs   = cosine_similarity_loss(gen, ref)
        l_ssim = ms_ssim_loss(gen, ref)
        return self.alpha_vgg * l_vgg + self.beta_cs * l_cs + self.gamma_s * l_ssim

    def least_squares_loss(self, pred: Tensor, target_real: bool = True) -> Tensor:
        """L_ls: least squares GAN loss (Eq. 10). More stable than cross-entropy GAN."""
        target = torch.ones_like(pred) if target_real else torch.zeros_like(pred)
        return F.mse_loss(pred, target)

    def generator_loss(self, gen: Tensor, ref: Tensor, d_fake: Tensor) -> Dict[str, Tensor]:
        """
        Full generator loss (Eq. 5).
        L_G = L_sim + λ·L_ls(D(I_gen,...), 1)
        """
        l_sim = self.similarity_loss(gen, ref)
        l_adv = self.least_squares_loss(d_fake, target_real=True)
        l_g   = l_sim + self.lambda_adv * l_adv
        return {'total': l_g, 'sim': l_sim, 'adv': l_adv}

    def gradient_penalty(
        self,
        discriminator: nn.Module,
        real: Tensor,
        fake: Tensor,
        cond: Tensor,
    ) -> Tensor:
        """
        WGAN gradient penalty (Eq. 11).
        Generates interpolated samples between real and fake,
        computes ||∇D(interpolated)||_2 and penalizes deviation from 1.
        """
        B = real.shape[0]
        alpha = torch.rand(B, 1, 1, 1, device=real.device)
        interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
        d_interp = discriminator(torch.cat([interp, cond], dim=1))
        grads = torch.autograd.grad(
            outputs=d_interp, inputs=interp,
            grad_outputs=torch.ones_like(d_interp),
            create_graph=True, retain_graph=True
        )[0]
        grads = grads.reshape(B, -1)
        return ((grads.norm(2, dim=1) - 1) ** 2).mean()

    def discriminator_loss(
        self,
        discriminator: nn.Module,
        real: Tensor,
        fake: Tensor,
        cond: Tensor,
    ) -> Tensor:
        """
        WGAN discriminator loss with gradient penalty (Eq. 11).
        L_D = -E[D(real)] + E[D(fake)] + λ_gp · L_gp
        """
        d_real = discriminator(torch.cat([real.detach(), cond], dim=1))
        d_fake = discriminator(torch.cat([fake.detach(), cond], dim=1))
        l_real = -d_real.mean()
        l_fake =  d_fake.mean()
        l_gp   = self.gradient_penalty(discriminator, real.detach(), fake.detach(), cond)
        return l_real + l_fake + self.lambda_gp * l_gp
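The gradient penalty has a closed form for a linear critic D(x) = ⟨w, x⟩: the input gradient is w everywhere, so the penalty is (‖w‖₂ − 1)² at any interpolation point. That gives a cheap correctness check of the Eq. 11 machinery:

```python
import torch

w = torch.tensor([[3.0, 4.0]])                      # ||w||_2 = 5
x = torch.randn(1, 2, requires_grad=True)           # any "interpolated" sample
d = (x * w).sum()                                   # linear critic output

(grad,) = torch.autograd.grad(d, x)                 # dD/dx == w everywhere
penalty = (grad.norm(2) - 1.0) ** 2
assert torch.isclose(penalty, torch.tensor(16.0))   # (5 - 1)^2
```

For the real discriminator the gradient varies with the sample, which is exactly why the penalty is evaluated on random interpolations between real and fake batches.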


# ─── SECTION 11: CRGenNet Wrapper + Training Loop ─────────────────────────────

class CRGenNet(nn.Module):
    """
    CRGenNet: complete wrapper combining Generator and Discriminator.
    Provides clean forward/predict interface and bundles loss computation.

    Usage:
      model = CRGenNet(cfg)
      I_gen = model.generate(s1_t1, s2_t1, s1_t2)   # inference
      losses = model.train_step(...)                   # training
    """
    def __init__(self, cfg: Optional[CRGenNetConfig] = None):
        super().__init__()
        cfg = cfg or CRGenNetConfig()
        self.cfg = cfg
        self.generator = CRGenNetGenerator(cfg)
        # Discriminator input: generated/real image + conditioning inputs
        dis_in_ch = cfg.out_channels + cfg.in_channels_s1 * 2 + cfg.in_channels_s2
        self.discriminator = MultiScaleDiscriminator(dis_in_ch)
        self.loss_fn = CRGenNetLoss(cfg)

    def generate(self, s1_t1: Tensor, s2_t1: Tensor, s1_t2: Tensor) -> Tensor:
        """Inference-only: generate cloud-free optical image at t2."""
        return self.generator(s1_t1, s2_t1, s1_t2)

    def train_step_g(
        self,
        s1_t1: Tensor, s2_t1: Tensor, s1_t2: Tensor, s2_t2_ref: Tensor,
        optimizer_g: torch.optim.Optimizer,
    ) -> Dict[str, float]:
        """Single generator training step."""
        self.generator.train()
        optimizer_g.zero_grad()
        I_gen = self.generator(s1_t1, s2_t1, s1_t2)
        cond = torch.cat([s1_t1, s1_t2, s2_t1], dim=1)
        d_fake = self.discriminator(torch.cat([I_gen, cond], dim=1))
        losses = self.loss_fn.generator_loss(I_gen, s2_t2_ref, d_fake)
        losses['total'].backward()
        torch.nn.utils.clip_grad_norm_(self.generator.parameters(), 5.0)
        optimizer_g.step()
        return {k: v.item() for k, v in losses.items()}

    def train_step_d(
        self,
        s1_t1: Tensor, s2_t1: Tensor, s1_t2: Tensor, s2_t2_ref: Tensor,
        optimizer_d: torch.optim.Optimizer,
    ) -> float:
        """Single discriminator training step."""
        self.discriminator.train()
        optimizer_d.zero_grad()
        with torch.no_grad():
            I_gen = self.generator(s1_t1, s2_t1, s1_t2)
        cond = torch.cat([s1_t1, s1_t2, s2_t1], dim=1)
        l_d = self.loss_fn.discriminator_loss(self.discriminator, s2_t2_ref, I_gen, cond)
        l_d.backward()
        optimizer_d.step()
        return l_d.item()


def build_optimizers(model: CRGenNet, cfg: CRGenNetConfig):
    """
    Adam optimizers as used in paper:
      Generator: lr=0.001, ReduceLROnPlateau (factor=0.5, patience=10 epochs)
      Discriminator: lr=0.001, weight_decay=1e-5
    """
    opt_g = torch.optim.Adam(
        model.generator.parameters(),
        lr=cfg.lr_g, betas=(cfg.beta1, cfg.beta2)
    )
    opt_d = torch.optim.Adam(
        model.discriminator.parameters(),
        lr=cfg.lr_d, betas=(cfg.beta1, cfg.beta2),
        weight_decay=1e-5
    )
    sched_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt_g, mode='min', factor=0.5, patience=10
    )
    return opt_g, opt_d, sched_g
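The plateau scheduler's behavior is easy to see on a throwaway parameter: with a constant (never-improving) validation loss, the learning rate halves once the patience window is exhausted:

```python
import torch

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([p], lr=1e-3)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode='min', factor=0.5, patience=10)

for _ in range(12):      # 1st call sets the best metric, then 11 bad epochs
    sched.step(1.0)      # constant validation loss never improves

assert abs(opt.param_groups[0]['lr'] - 5e-4) < 1e-12   # halved exactly once
```

In the training loop, `sched_g.step(val_loss)` goes after each validation pass; feed it the metric you actually want to plateau on, not the training loss.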


# ─── SECTION 12: Smoke Test ────────────────────────────────────────────────────

class SyntheticSentinelBatch:
    """
    Synthetic SAR-optical data generator for smoke testing.
    Replace with real TCSEN12 or DTSEN1-2 dataloaders for actual training.

    For real data:
      Dataset: https://github.com/chenxiduan/MultiTemporalCloudFree
      GEE script for Sentinel-1/-2 download: see paper supplementary material
      Recommended framework: PyTorch Lightning DataModule
    """
    def __init__(self, B: int = 2, H: int = 64, W: int = 64, cfg: CRGenNetConfig = None):
        self.B = B
        self.H = H
        self.W = W
        self.cfg = cfg or CRGenNetConfig()

    def get_batch(self, device):
        s1_ch = self.cfg.in_channels_s1
        s2_ch = self.cfg.in_channels_s2
        B, H, W = self.B, self.H, self.W
        return {
            's1_t1': torch.randn(B, s1_ch, H, W, device=device) * 0.3,  # SAR backscatter
            's2_t1': torch.randn(B, s2_ch, H, W, device=device) * 0.2,  # Contaminated optical
            's1_t2': torch.randn(B, s1_ch, H, W, device=device) * 0.3,  # Target date SAR
            's2_t2': torch.randn(B, s2_ch, H, W, device=device) * 0.2,  # Reference (ground truth)
        }


def run_smoke_test(steps: int = 3, device_str: str = "cpu"):
    device = torch.device(device_str)
    print("=" * 60)
    print("  CRGenNet — Full Architecture Smoke Test")
    print("=" * 60)
    torch.manual_seed(42)

    # Tiny config for fast test
    cfg = CRGenNetConfig(base_ch=16, swin_embed_dim=16, num_swin_blocks=2)
    model = CRGenNet(cfg).to(device)

    n_params_g = sum(p.numel() for p in model.generator.parameters()) / 1e6
    n_params_d = sum(p.numel() for p in model.discriminator.parameters()) / 1e6
    print(f"\n  Generator:     {n_params_g:.2f}M parameters")
    print(f"  Discriminator: {n_params_d:.2f}M parameters")
    print(f"  Total:         {n_params_g + n_params_d:.2f}M parameters")

    # [1/3] Forward pass check
    print("\n[1/3] Generator forward pass...")
    data_gen = SyntheticSentinelBatch(B=1, H=64, W=64, cfg=cfg)
    batch = data_gen.get_batch(device)
    model.generator.eval()
    with torch.no_grad():
        out = model.generate(batch['s1_t1'], batch['s2_t1'], batch['s1_t2'])
    print(f"  Input:  s1_t1={tuple(batch['s1_t1'].shape)}, s2_t1={tuple(batch['s2_t1'].shape)}, s1_t2={tuple(batch['s1_t2'].shape)}")
    print(f"  Output: {tuple(out.shape)}  range=[{out.min():.2f}, {out.max():.2f}]")
    assert out.shape == (1, cfg.out_channels, 64, 64), "Output shape mismatch!"
    assert -1.1 <= out.min().item() and out.max().item() <= 1.1, "Output out of [-1,1] range!"

    # [2/3] Loss check
    print("\n[2/3] Loss function check...")
    gen_img = torch.randn(1, cfg.out_channels, 64, 64, device=device)
    ref_img = torch.randn(1, cfg.out_channels, 64, 64, device=device)
    loss_fn = CRGenNetLoss(cfg)
    l_sim = loss_fn.similarity_loss(gen_img, ref_img)
    l_cs  = cosine_similarity_loss(gen_img, ref_img)
    l_ms  = ms_ssim_loss(gen_img, ref_img)
    print(f"  L_sim={l_sim.item():.4f}  L_CS={l_cs.item():.4f}  L_MS-SSIM={l_ms.item():.4f}")

    # [3/3] Training loop
    print(f"\n[3/3] Training loop ({steps} steps)...")
    model.train()
    opt_g, opt_d, sched_g = build_optimizers(model, cfg)
    data_gen = SyntheticSentinelBatch(B=2, H=64, W=64, cfg=cfg)

    for step in range(steps):
        batch = data_gen.get_batch(device)
        s1t1, s2t1, s1t2, s2t2 = batch['s1_t1'], batch['s2_t1'], batch['s1_t2'], batch['s2_t2']

        # Discriminator update
        l_d = model.train_step_d(s1t1, s2t1, s1t2, s2t2, opt_d)

        # Generator update
        g_losses = model.train_step_g(s1t1, s2t1, s1t2, s2t2, opt_g)

        print(f"  Step {step+1}/{steps} | L_D={l_d:.4f} | L_G={g_losses['total']:.4f} | L_sim={g_losses['sim']:.4f}")

    print("\n" + "=" * 60)
    print("✓  All checks passed. CRGenNet is ready for training.")
    print("=" * 60)
    print("""
Next steps:
  1. Download the TCSEN12 dataset:
       https://github.com/chenxiduan/MultiTemporalCloudFree
  2. Install required packages:
       pip install torch torchvision einops timm
  3. For full Swin Transformer blocks (recommended for production):
       from timm.models.swin_transformer import SwinTransformerBlock
  4. Scale to paper configuration:
       cfg = CRGenNetConfig(base_ch=64, swin_embed_dim=96)
  5. Train on NVIDIA A40 (as in paper) with batch=8, 200 epochs:
       python train.py --config tcsen12.yaml
  6. For multispectral input (all 12 Sentinel-2 bands):
       cfg = CRGenNetConfig(in_channels_s2=12, out_channels=12)
""")
    return model


if __name__ == "__main__":
    run_smoke_test(steps=3, device_str="cpu")

Read the Full Paper & Get the Dataset

The complete study — including full ablation tables, feature map visualizations from the FusionAttention module, and results across woodland, urban, farmland, and village scenes — is published open-access in the ISPRS Journal of Photogrammetry and Remote Sensing.

Academic Citation:
Duan, C., Belgiu, M., & Stein, A. (2026). High-quality cloud-free optical image generation using multi-temporal SAR and contaminated optical data. ISPRS Journal of Photogrammetry and Remote Sensing, 236, 255–272. https://doi.org/10.1016/j.isprsjprs.2026.03.042

This article is an independent editorial analysis of peer-reviewed open-access research. The PyTorch implementation is an educational adaptation intended to illustrate the paper’s architectural concepts. The original authors trained on a single NVIDIA A40 GPU with batch size 8 for 200 epochs, using the Adam optimizer with ReduceLROnPlateau scheduling. For production deployments on real Sentinel-1/-2 data, use the full multispectral band configuration and the official TCSEN12 dataset. All Sentinel data is Copernicus open-access imagery.
