Shallow Features Matter: How HMHI-Net Fixes the Fundamental Flaw in Video Object Segmentation Memory
Fudan University researchers found that existing memory-based unsupervised video segmentation methods make the same mistake: they memorize only high-level semantic features and discard the fine-grained pixel details that unsupervised segmentation needs most. Their fix, HMHI-Net, reaches 89.8% J&F on DAVIS-16, 1.6 points above the previous state of the art, with a surprisingly simple insight.
Imagine watching a video of a horse galloping through a field and being asked to draw a precise mask around the horse in every frame, without anyone telling you what a horse looks like. That’s exactly what Unsupervised Video Object Segmentation (UVOS) asks an AI to do. The key challenge is not just identifying what moves, but outlining it with pixel-level precision when there’s no prior mask to guide you. Memory mechanisms, where the model stores information from past frames to inform current predictions, seemed like the natural solution. But as researchers at Fudan University demonstrate in this paper, existing memory-based UVOS approaches share the same blind spot: they only remember high-level semantic features and discard the rich pixel-level details living in the shallow encoder layers. HMHI-Net fixes this by remembering both, and the performance jump speaks for itself.
The Blind Spot Nobody Was Talking About
To understand why this matters, you need to understand how encoder networks process images. A typical hierarchical encoder (like SegFormer’s MiT backbone) processes an image through four progressively deeper layers. At each layer, the spatial resolution drops and the feature channels grow more abstract. By the fourth layer, the feature map is small, compact, and highly semantic — it “knows” there’s a horse in the scene, but it’s lost most of the boundary, texture, and precise location information.
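To make the resolution and channel trade-off concrete, the short sketch below tabulates the feature-map shape at each encoder level for a 512×512 input. It assumes the cumulative strides (4, 8, 16, 32) and the MiT-b1 channel widths (64, 128, 320, 512) that appear later in this article; the helper name `pyramid_shapes` is purely illustrative.

```python
from typing import List, Tuple

# Assumed values: cumulative strides and MiT-b1 channel widths per encoder level.
STRIDES = (4, 8, 16, 32)
CHANNELS = (64, 128, 320, 512)


def pyramid_shapes(height: int, width: int) -> List[Tuple[int, int, int]]:
    """Return (H_i, W_i, C_i) for each of the four encoder levels."""
    return [(height // s, width // s, c) for s, c in zip(STRIDES, CHANNELS)]


if __name__ == "__main__":
    for level, (h, w, c) in enumerate(pyramid_shapes(512, 512), start=1):
        print(f"level {level}: {h}x{w} map, {c} channels, {h * w} tokens")
```

Under these assumptions, level 2 keeps 64×64 = 4096 spatial positions while level 4 keeps only 16×16 = 256, which is one way to see why the deepest features carry so little boundary information.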
The paper visualizes this beautifully with attention maps. At encoder levels 1 and 2, the attention spreads across the general foreground pixels of the object — the model is paying attention to the whole outline of the horse, the hooves, the mane. By levels 3 and 4, the attention collapses to just a handful of tokens representing the most semantically representative points — essentially, the model is now only looking at “the horse-ness,” not “where the horse actually is.”
Previous memory-based UVOS methods — like PMN, DPA, and TGFormer — all store these high-level level-4 features in their memory banks. When the current frame arrives, the model looks up similar high-level patterns from memory and uses them to guide segmentation. The problem is that the high-level memory can tell the decoder “there should be a horse-like object roughly here,” but it can’t give precise boundary information. In the SVOS (semi-supervised) setting, this isn’t catastrophic because you already have a pixel-accurate mask of the first frame as guidance. In UVOS, you have nothing — the model is starting blind, and high-level memory alone cannot compensate for that fundamental absence.
Adding memory at encoder level 2 improves UVOS performance by 0.8% J&F, while adding memory at encoder level 4 (the standard approach) improves it by only 0.2%. The second-level features preserve fine-grained pixel layout information that the fourth-level features have already discarded, and that pixel information is exactly what UVOS is missing. The solution is to maintain both memory banks simultaneously and let them inform each other.
The Full HMHI-Net Architecture
HMHI-Net builds on this insight with three interconnected components: a hierarchical memory structure, a heterogeneous interaction mechanism between the two memory levels, and a memory update strategy that stores predicted masks along with features.
INPUT: RGB Frame I_t (H×W×3) + Optical Flow O_t (H×W×3)
         │
┌────────▼──────────────────────────────────────────────────┐
│  HIERARCHICAL BACKBONE (MiT-b1 / SegFormer)               │
│                                                           │
│  Level 1 → F¹_t: H/4  × W/4  × C1   (pixel details)       │
│  Level 2 → F²_t: H/8  × W/8  × C2   ← SHALLOW MEMORY      │
│  Level 3 → F³_t: H/16 × W/16 × C3   (not memorized)       │
│  Level 4 → F⁴_t: H/32 × W/32 × C4   ← HIGH-LEVEL MEMORY   │
│                                                           │
│  Merged: F^i_t = I^i_t + O^i_t  (image + flow addition)   │
└────────┬──────────────────────────────────────────────────┘
         │
┌────────▼──────────────────────────────────────────────────┐
│  HIERARCHICAL MEMORY REFINEMENT                           │
│                                                           │
│  Pixel Memory:    F²'_t = Mem_Refine(F²_t, R²_{t-1})      │
│  Semantic Memory: F⁴'_t = Mem_Refine(F⁴_t, R⁴_{t-1})      │
│                                                           │
│  Each Mem_Refine = Self-Attn → Cross-Attn(memory) → FFN   │
└────────┬──────────────────────────────────────────────────┘
         │
┌────────▼──────────────────────────────────────────────────┐
│  HETEROGENEOUS INTERACTION                                │
│                                                           │
│  SGIM: F²''_t = SGIM(F⁴'_t, F²'_t)   [high → shallow]     │
│        Global attention: Q from F²', K,V from F⁴_tmp      │
│                                                           │
│  PLAM: F⁴''_t = PLAM(F²'_t, F⁴'_t)   [shallow → high]     │
│        ConvReLU downsample → concat → ChannelAttn         │
│        → SpatialAttn → FFN                                │
└────────┬──────────────────────────────────────────────────┘
         │  Updated features: [F¹_t, F²''_t, F³_t, F⁴''_t]
┌────────▼──────────────────────────────────────────────────┐
│  HIERARCHICAL DECODER (multi-scale, bottom-up fusion)     │
│  → M_Pred ∈ R^{H×W×1}  (segmentation mask)                │
└────────┬──────────────────────────────────────────────────┘
         │
┌────────▼──────────────────────────────────────────────────┐
│  MEMORY UPDATE (sliding window, size T=5)                 │
│  R²_t = Mem_Update(R²_{t-1}, F²''_t, M_Pred)              │
│  R⁴_t = Mem_Update(R⁴_{t-1}, F⁴''_t, M_Pred)              │
└───────────────────────────────────────────────────────────┘
Memory Refinement: How Past Frames Inform the Present
The memory refinement process is the same for both the shallow (level-2) and high-level (level-4) banks, which is an elegant design choice that prevents misalignment between the two. For the current frame feature F^i_t, the module first applies self-attention to strengthen the feature’s own internal representations:
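In standard attention notation, this self-attention step can be sketched as follows. The projection matrices $W_Q$, $W_K$, $W_V$, the key dimension $d$, the symbol $\tilde{F}^i_t$ for the result, and the residual form are assumptions chosen to match the implementation later in this article; they are not copied from the paper's equations, and layer normalization is omitted for brevity.

```latex
\tilde{F}^i_t \;=\; F^i_t \;+\;
\mathrm{softmax}\!\left(
  \frac{\bigl(F^i_t W_Q\bigr)\bigl(F^i_t W_K\bigr)^{\top}}{\sqrt{d}}
\right) F^i_t W_V,
\qquad i \in \{2, 4\}
```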
Then the cross-attention between the current frame and all stored reference frames in the memory bank extracts the relevant history. The correlation score Scorr measures how similar each pixel in the current frame is to all pixels across all T memorized frames, then pulls weighted information from the memory values:
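A hedged sketch of this memory readout in standard notation is given below. Here $\tilde{F}^i_t \in \mathbb{R}^{N \times C}$ denotes the current-frame feature after self-attention, $R^i_{t-1} \in \mathbb{R}^{TN \times C}$ the flattened memory bank of $T$ reference frames, and $W_Q$, $W_K$, $W_V$ learned projections. The symbols other than $S_{\mathrm{corr}}$ and the residual connection are assumptions that mirror the implementation later in this article, not the paper's exact formulation.

```latex
S_{\mathrm{corr}} \;=\;
\mathrm{softmax}\!\left(
  \frac{\bigl(\tilde{F}^i_t W_Q\bigr)\bigl(R^i_{t-1} W_K\bigr)^{\top}}{\sqrt{d}}
\right) \in \mathbb{R}^{N \times TN},
\qquad
F^{i\,\prime}_t \;=\; \tilde{F}^i_t \;+\; S_{\mathrm{corr}} \bigl(R^i_{t-1} W_V\bigr)
```

Each row of $S_{\mathrm{corr}}$ is a distribution over all $TN$ memorized positions, so every current-frame pixel receives a weighted mixture of memory values. A feed-forward network then follows, as in the architecture diagram.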
PLAM: Injecting Pixel Detail Into Semantic Features
The Pixel-guided Local Alignment Module (PLAM) enriches the high-level feature F⁴’_t with the structural information living in the shallow feature F²’_t. The design respects the spatial nature of shallow features — rather than applying global attention (which would muddy fine-grained spatial relationships), PLAM uses a position-preserving approach.
First, F²’_t is downsampled to match the spatial resolution of F⁴’_t through ConvReLU operations. The aligned shallow feature is then directly concatenated with F⁴’_t along the channel dimension, creating a unified 2C-dimensional representation. Sequential channel attention and spatial attention sharpen which channels and which locations are most informative, and an FFN projects back to the original feature space. The entire operation maintains spatial coherence — nearby shallow pixels inform nearby high-level tokens, which is why “local alignment” is in the name.
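The shape bookkeeping of these steps can be traced with a few lines of plain Python. This is a sketch under stated assumptions: two 3×3, stride-2, padding-1 ConvReLU layers perform the downsampling (enough to go from H/8 to H/32), and the downsampled shallow feature is given C4 channels so that concatenation yields 2·C4. The helper names `conv_out` and `plam_shapes` are illustrative, not from the paper.

```python
from typing import Tuple

Shape = Tuple[int, int, int]  # (H, W, C)


def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def plam_shapes(h2: int, w2: int, c4: int) -> Tuple[Shape, Shape, Shape]:
    """Trace (H, W, C) through downsample -> concat -> FFN."""
    h, w = h2, w2
    for _ in range(2):  # assumed: two stride-2 ConvReLU layers
        h, w = conv_out(h), conv_out(w)
    aligned = (h, w, c4)      # shallow feature aligned to the level-4 grid
    concat = (h, w, 2 * c4)   # channel-wise concatenation with the level-4 feature
    output = (h, w, c4)       # FFN projects back to the original C4 channels
    return aligned, concat, output


if __name__ == "__main__":
    # Level-2 map for a 512x512 input is 64x64; level 4 has 512 channels.
    print(plam_shapes(64, 64, 512))
```

Because every step acts on the same 2D grid, each high-level token is only ever mixed with the shallow pixels that fall under it, which is the position-preserving property the paragraph describes.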
SGIM: Injecting Semantic Understanding Into Pixel Features
The Semantic-guided Global Integration Module (SGIM) works in the opposite direction — pulling semantic context from F⁴’_t down into the shallow feature F²’_t. This prevents the common problem where fine-grained details are lost because the decoder never receives semantic guidance at the pixel level.
SGIM uses a global cross-attention strategy: queries come from F²’_t (the pixel features), while keys and values come from F⁴_tmp, the semantic feature F⁴’_t linearly projected to match the shallow feature’s channel dimension. This allows every pixel in the shallow feature map to attend globally to the semantic content, pulling in object-level understanding without being constrained by local neighborhoods. The “global” in SGIM’s name reflects this broader receptive field compared to PLAM’s local alignment.
The ablation study shows that swapping PLAM and SGIM hurts performance. Using SGIM for the shallow-to-high (S2H) direction drops DAVIS-16 J&F by 0.3%. Using PLAM for the high-to-shallow (H2S) direction drops it by 0.4%. The heterogeneous design, with local alignment for pixel injection and global attention for semantic injection, reflects the fundamentally different natures of what each feature type needs from the other.
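A quick count of tokens shows why global attention is affordable in this direction. The sketch below assumes a 512×512 input, so level 2 is a 64×64 grid and level 4 a 16×16 grid; `cross_attention_size` is an illustrative helper, not part of the paper.

```python
from typing import Tuple


def cross_attention_size(hq: int, wq: int, hk: int, wk: int) -> Tuple[int, int, int]:
    """Return (num_queries, num_keys, entries in the attention score matrix)."""
    n_query = hq * wq
    n_key = hk * wk
    return n_query, n_key, n_query * n_key


if __name__ == "__main__":
    # SGIM: shallow pixels (queries) attend to semantic tokens (keys/values).
    n_q, n_k, sgim_scores = cross_attention_size(64, 64, 16, 16)
    # For comparison: full self-attention among the shallow pixels themselves.
    _, _, self_scores = cross_attention_size(64, 64, 64, 64)
    print(n_q, n_k, sgim_scores, self_scores // sgim_scores)
```

With these sizes, each of the 4096 shallow queries attends to only 256 semantic tokens, so the score matrix is 16 times smaller than full self-attention at level 2 would need.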
Results: State of the Art Across Every Benchmark
UVOS Performance
| Method | Venue | Backbone | DAVIS-16 J&F | FBMS J | YouTube-Obj J |
|---|---|---|---|---|---|
| GFA | AAAI’24 | — | 88.2 | 82.4 | 74.7 |
| GSA | CVPR’24 | ResNet-101 | 87.7 | 83.1 | — |
| DPA | CVPR’24 | VGG-16 | 87.6 | 83.4 | 73.7 |
| SimulFlow | ACMMM’23 | MiT-b1 | 87.4 | 80.4 | 72.9 |
| OAST | ICCV’23 | MobileViT | 87.0 | 83.0 | — |
| HMHI-Net (ours) | ACMMM’25 | MiT-b1 | 89.8 (+1.6) | 86.9 (+3.5) | 76.2 (+1.5) |
Table: UVOS benchmark comparison. Gains shown relative to previous best. All methods use optical flow.
The FBMS gain of 3.5% is particularly striking — FBMS contains challenging multi-object scenes with significant motion complexity, and the shallow memory’s ability to preserve fine-grained boundary information matters enormously when multiple moving objects must be simultaneously tracked.
Ablation: What Each Component Contributes
| Variant | DAVIS-16 J&F | FBMS J | YTB-Obj J | FPS | Params (M) |
|---|---|---|---|---|---|
| Baseline | 88.4 | 84.7 | 75.1 | 34.4 | 36.7 |
| + Hierarchical Memory | 89.3 (+0.9) | 86.0 (+1.3) | 76.1 (+1.0) | 27.8 | 47.7 |
| + PLAM (S2H) | 89.4 (+1.0) | 86.3 (+1.6) | 75.6 (+0.5) | 27.3 | 60.0 |
| + SGIM (H2S) | 89.7 (+1.3) | 86.5 (+1.8) | 75.3 (+0.2) | 26.9 | 48.4 |
| HMHI-Net (full) | 89.8 (+1.4) | 86.9 (+2.2) | 76.2 (+1.1) | 26.2 | 60.8 |
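The accuracy/cost trade-off in the table can be recomputed directly from its rows. The sketch below copies the numbers from the ablation table and derives, for any variant, the accuracy gains over the baseline, the relative FPS drop, and the added parameters; the names `ROWS` and `tradeoff` are illustrative.

```python
from typing import Tuple

# (DAVIS-16 J&F, FBMS J, YTB-Obj J, FPS, Params in M), copied from the table above.
ROWS = {
    "Baseline": (88.4, 84.7, 75.1, 34.4, 36.7),
    "+ Hierarchical Memory": (89.3, 86.0, 76.1, 27.8, 47.7),
    "+ PLAM (S2H)": (89.4, 86.3, 75.6, 27.3, 60.0),
    "+ SGIM (H2S)": (89.7, 86.5, 75.3, 26.9, 48.4),
    "HMHI-Net (full)": (89.8, 86.9, 76.2, 26.2, 60.8),
}


def tradeoff(name: str) -> Tuple[Tuple[float, ...], float, float]:
    """Return (accuracy gains, FPS drop in percent, extra params in M) vs. baseline."""
    base, row = ROWS["Baseline"], ROWS[name]
    gains = tuple(round(r - b, 1) for r, b in zip(row[:3], base[:3]))
    fps_drop_pct = round(100 * (base[3] - row[3]) / base[3], 1)
    extra_params = round(row[4] - base[4], 1)
    return gains, fps_drop_pct, extra_params


if __name__ == "__main__":
    for variant in ROWS:
        print(variant, tradeoff(variant))
```

For the full model this works out to roughly a 24% drop in FPS and 24.1 M additional parameters relative to the baseline, which is the price paid for the accuracy gains.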
Cross-Backbone Robustness
One of the most reassuring results in the paper is that HMHI-Net’s improvements are not tied to a specific backbone. Across MiT-b1, MiT-b2, MiT-b3, and Swin-Tiny, the hierarchical memory and heterogeneous interaction consistently deliver gains ranging from +1.0% to +1.6% on DAVIS-16. The Swin-Tiny backbone sees a particularly large +2.3% improvement on YouTube-Objects, suggesting the shallow memory mechanism is especially helpful when the backbone’s feature pyramid is less well-calibrated for pixel-level tasks.
“UVOS inherently suffers from the deficiency of lacking fine-grained information due to the absence of pixel-level prior knowledge. High-level memory alone can hardly compensate for this intrinsic absence.” — Zheng, He, Li, Li & Zhang, ACM Multimedia 2025
Training Details
HMHI-Net is pre-trained on YouTube-VOS for 150 epochs with AdamW (lr=6e-5), then fine-tuned on DAVIS-16 or FBMS with a cosine annealing scheduler (lr=1e-4) until convergence. The training loss combines binary cross-entropy, focal loss, and Dice loss — a combination specifically chosen to handle the foreground/background imbalance inherent in video segmentation. Training sequences consist of five frames with k=1 (update every frame) and T=5 (keep 5 reference frames in memory). Importantly, the first frame in each sequence skips memory refinement entirely since no reference frames exist yet — it only populates both memory banks for future use. Everything runs on four NVIDIA RTX 4090 GPUs at 512×512 resolution.
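The memory schedule described above (T=5 reference frames, an update every k=1 frames, and a first frame that only populates the banks) can be simulated with a bounded deque. This is a minimal sketch: it tracks frame indices rather than real features, and the rule "update when t is a multiple of k" is an assumption about how the update frequency is applied.

```python
from collections import deque
from typing import Deque, List, Tuple


def simulate_memory(num_frames: int, window: int = 5, update_every: int = 1) -> List[Tuple[int, List[int]]]:
    """For each frame t, list the reference frames visible during memory refinement."""
    bank: Deque[int] = deque(maxlen=window)  # FIFO: the oldest reference is evicted
    trace: List[Tuple[int, List[int]]] = []
    for t in range(num_frames):
        # Frame 0 sees an empty bank, so memory refinement is skipped for it.
        trace.append((t, list(bank)))
        if t % update_every == 0:
            bank.append(t)  # stands in for storing this frame's mask-fused feature
    return trace


if __name__ == "__main__":
    for t, refs in simulate_memory(7):
        status = "skip refinement" if not refs else f"attend to frames {refs}"
        print(f"frame {t}: {status}")
```

With a five-frame training sequence the window never overflows; eviction only matters at inference on longer videos, where frame 6 would see frames 1 to 5.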
Complete End-to-End HMHI-Net Implementation (PyTorch)
The implementation below is a self-contained PyTorch re-implementation of HMHI-Net, organized into 11 sections that map directly to the paper. Note that the encoder is a simplified MiT-b1-style stack rather than the pretrained SegFormer backbone. The code covers the hierarchical encoder, the dual memory banks (shallow + high-level), the memory refinement module (self-attention + cross-attention), PLAM (pixel-guided local alignment), SGIM (semantic-guided global integration), the hierarchical decoder, the sliding-window memory update with reference encoder, the combined BCE + focal + Dice training loss, a synthetic video dataset loader, the full training loop with pre-training and fine-tuning stages, and a smoke test.
# ==============================================================================
# HMHI-Net: Hierarchical Memory with Heterogeneous Interaction
# for Unsupervised Video Object Segmentation
# Paper: arXiv:2507.22465v1 | ACM Multimedia 2025
# Authors: Xiangyu Zheng, Songcheng He, Wanyun Li, Xiaoqiang Li, Wei Zhang
# Affiliation: Fudan University / Shanghai University
# Project: https://github.com/ZhengxyFlow/HMHI-Net
# ==============================================================================
# Sections:
# 1. Imports & Configuration
# 2. Hierarchical Encoder (SegFormer-style MiT-b1)
# 3. Memory Refinement Module (Self-Attn + Cross-Attn + FFN)
# 4. PLAM: Pixel-guided Local Alignment Module
# 5. SGIM: Semantic-guided Global Integration Module
# 6. Hierarchical Memory Banks (shallow + high-level)
# 7. Hierarchical Decoder (multi-scale FPN-style)
# 8. Reference Encoder for Memory Update
# 9. Full HMHI-Net Model
# 10. Loss Functions (BCE + Focal + Dice)
# 11. Dataset, Training Loop & Smoke Test
# ==============================================================================
from __future__ import annotations
import math
import warnings
from collections import deque
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
warnings.filterwarnings("ignore")
# ─── SECTION 1: Configuration ─────────────────────────────────────────────────
class HMHIConfig:
    """
    HMHI-Net configuration. Defaults match the paper's MiT-b1 setting.
    """
    # Image dimensions
    img_h: int = 512
    img_w: int = 512
    # Encoder channel dimensions per level (MiT-b1); populated in __init__
    enc_channels: Optional[List[int]] = None  # [64, 128, 320, 512]
    # Decoder
    dec_channels: int = 256
    # Memory configuration
    memory_size: int = 5          # sliding window T=5
    memory_update_freq: int = 1   # update every k=1 frames
    # Attention
    num_heads: int = 8
    head_dim: int = 64
    # Training
    lr_pretrain: float = 6e-5     # YouTube-VOS pre-training
    lr_finetune: float = 1e-4     # DAVIS-16 fine-tuning
    epochs_pretrain: int = 150
    seq_len: int = 5              # frames per training sequence

    def __init__(self, **kwargs):
        # Build the list per instance so configs never share a mutable default.
        self.enc_channels = [64, 128, 320, 512]
        for k, v in kwargs.items():
            setattr(self, k, v)
# ─── SECTION 2: Hierarchical Encoder ──────────────────────────────────────────
class DWConv(nn.Module):
    """Depth-wise convolution for efficient local feature mixing in MiT blocks."""

    def __init__(self, dim: int):
        super().__init__()
        self.dw = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

    def forward(self, x: Tensor, H: int, W: int) -> Tensor:
        B, N, C = x.shape
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.dw(x).flatten(2).transpose(1, 2)
        return x


class MixFFN(nn.Module):
    """Mix-FFN from SegFormer: Linear + DWConv + GELU + Linear."""

    def __init__(self, dim: int, expand: int = 4):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * expand)
        self.dw = DWConv(dim * expand)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim * expand, dim)
        self.drop = nn.Dropout(0.1)

    def forward(self, x: Tensor, H: int, W: int) -> Tensor:
        x = self.fc1(x)
        x = self.dw(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class EfficientSelfAttn(nn.Module):
    """
    Efficient self-attention from SegFormer: reduces sequence length
    via spatial reduction ratio for lower memory/compute cost.
    """

    def __init__(self, dim: int, num_heads: int = 8, sr_ratio: int = 1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # Strided conv shrinks the key/value sequence by sr_ratio ** 2
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor, H: int, W: int) -> Tensor:
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        if self.sr_ratio > 1:
            x_r = x.transpose(1, 2).reshape(B, C, H, W)
            x_r = self.sr(x_r).flatten(2).transpose(1, 2)
            x_r = self.norm(x_r)
        else:
            x_r = x
        kv = self.kv(x_r).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)


class MiTBlock(nn.Module):
    """Single Mix Transformer block: LayerNorm → EfficientAttn → LayerNorm → MixFFN."""

    def __init__(self, dim: int, num_heads: int = 8, sr_ratio: int = 1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = EfficientSelfAttn(dim, num_heads, sr_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = MixFFN(dim)

    def forward(self, x: Tensor, H: int, W: int) -> Tensor:
        x = x + self.attn(self.norm1(x), H, W)
        x = x + self.ffn(self.norm2(x), H, W)
        return x
class MiTStage(nn.Module):
    """
    One stage of the Mix Transformer encoder.
    Overlapping patch embedding → stack of MiT blocks.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        num_blocks: int = 2,
        num_heads: int = 2,
        sr_ratio: int = 8,
        patch_size: int = 7,
        stride: int = 4,
    ):
        super().__init__()
        # Overlapping patch embed
        self.patch_embed = nn.Conv2d(
            in_ch, out_ch, kernel_size=patch_size,
            stride=stride, padding=patch_size // 2
        )
        self.norm = nn.LayerNorm(out_ch)
        self.blocks = nn.ModuleList([
            MiTBlock(out_ch, num_heads, sr_ratio)
            for _ in range(num_blocks)
        ])

    def forward(self, x: Tensor) -> Tuple[Tensor, int, int]:
        """Returns (tokens, H_out, W_out)."""
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        x = self.norm(x)
        for blk in self.blocks:
            x = blk(x, H, W)
        return x, H, W
class HierarchicalEncoder(nn.Module):
    """
    Simplified MiT-b1 style encoder (Section 3.1, Eq. 1).
    Takes image I_t and optical flow O_t, extracts 4-level features,
    adds them element-wise to produce merged features F^i_t.
    Returns F1, F2, F3, F4, each as a (B, H_i*W_i, C_i) token sequence.
    """

    def __init__(self, in_ch: int = 3, channels: Optional[List[int]] = None):
        super().__init__()
        C = channels or [64, 128, 320, 512]
        # Image encoder stages: (in, out, blocks, heads, sr_ratio, patch, stride)
        self.img_stage1 = MiTStage(in_ch, C[0], 2, 1, 8, 7, 4)
        self.img_stage2 = MiTStage(C[0], C[1], 2, 2, 4, 3, 2)
        self.img_stage3 = MiTStage(C[1], C[2], 2, 5, 2, 3, 2)
        self.img_stage4 = MiTStage(C[2], C[3], 2, 8, 1, 3, 2)
        # Optical flow encoder (same architecture, separate weights)
        self.flow_stage1 = MiTStage(in_ch, C[0], 2, 1, 8, 7, 4)
        self.flow_stage2 = MiTStage(C[0], C[1], 2, 2, 4, 3, 2)
        self.flow_stage3 = MiTStage(C[1], C[2], 2, 5, 2, 3, 2)
        self.flow_stage4 = MiTStage(C[2], C[3], 2, 8, 1, 3, 2)
        self.channels = C

    def forward(
        self, image: Tensor, flow: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, List[Tuple[int, int]]]:
        """
        image: (B, 3, H, W)
        flow: (B, 3, H, W)
        Returns F1, F2, F3, F4 (merged image + flow features)
        and spatial sizes [(H1,W1), (H2,W2), (H3,W3), (H4,W4)]
        """
        # Reshape token sequences back to 2D feature maps between stages
        def to_2d(x: Tensor, H: int, W: int) -> Tensor:
            return x.transpose(1, 2).reshape(x.shape[0], -1, H, W)

        i1, H1, W1 = self.img_stage1(image)
        o1, _, _ = self.flow_stage1(flow)
        F1 = i1 + o1
        i2, H2, W2 = self.img_stage2(to_2d(i1, H1, W1))
        o2, _, _ = self.flow_stage2(to_2d(o1, H1, W1))
        F2 = i2 + o2
        i3, H3, W3 = self.img_stage3(to_2d(i2, H2, W2))
        o3, _, _ = self.flow_stage3(to_2d(o2, H2, W2))
        F3 = i3 + o3
        i4, H4, W4 = self.img_stage4(to_2d(i3, H3, W3))
        o4, _, _ = self.flow_stage4(to_2d(o3, H3, W3))
        F4 = i4 + o4
        sizes = [(H1, W1), (H2, W2), (H3, W3), (H4, W4)]
        return F1, F2, F3, F4, sizes
# ─── SECTION 3: Memory Refinement Module ──────────────────────────────────────
class MemoryRefinement(nn.Module):
    """
    Unified memory readout mechanism (Section 3.2, Eq. 5–7).
    Applies the same architecture to both shallow (level-2) and
    high-level (level-4) memory banks to avoid misalignment.
    Three steps:
      1. Self-attention on current feature (Eq. 5): internal coherence
      2. Cross-attention with memory bank (Eq. 6–7): temporal retrieval
      3. Feed-forward network: realigns feature representation space
    dim: channel dimension of the features
    num_heads: number of attention heads (dim must be divisible by it)
    """

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # Self-attention projections
        self.sa_qkv = nn.Linear(dim, dim * 3)
        self.sa_proj = nn.Linear(dim, dim)
        self.sa_norm = nn.LayerNorm(dim)
        # Cross-attention with memory projections
        self.mem_q = nn.Linear(dim, dim)
        self.mem_kv = nn.Linear(dim, dim * 2)
        self.mem_proj = nn.Linear(dim, dim)
        self.mem_norm = nn.LayerNorm(dim)
        # FFN to realign back to backbone representation space
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        self.ffn_norm = nn.LayerNorm(dim)

    def _attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Scaled dot-product attention on (B, heads, N, head_dim) tensors."""
        B, H, N, D = q.shape
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, N, M)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        return out

    def forward(self, F_t: Tensor, R_mem: Optional[Tensor] = None) -> Tensor:
        """
        F_t: (B, N, C), current frame feature
        R_mem: (B, T*N, C), reference memory (None for first frame)
        Returns refined feature F'_t: (B, N, C)
        """
        B, N, C = F_t.shape
        # Step 1: Self-attention (Eq. 5)
        x = self.sa_norm(F_t)
        qkv = self.sa_qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        F_t = F_t + self.sa_proj(self._attn(q, k, v))
        # Step 2: Cross-attention with memory (Eq. 6–7); skipped when the bank is empty
        if R_mem is not None:
            x = self.mem_norm(F_t)
            q_mem = self.mem_q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
            kv_mem = self.mem_kv(R_mem).reshape(B, R_mem.size(1), 2, self.num_heads, self.head_dim)
            kv_mem = kv_mem.permute(2, 0, 3, 1, 4)
            k_mem, v_mem = kv_mem[0], kv_mem[1]
            F_t = F_t + self.mem_proj(self._attn(q_mem, k_mem, v_mem))
        # Step 3: FFN realignment
        F_t = F_t + self.ffn(self.ffn_norm(F_t))
        return F_t
# ─── SECTION 4: PLAM — Pixel-guided Local Alignment Module ────────────────────
class CBAM(nn.Module):
    """
    Simplified CBAM-style sequential channel + spatial attention.
    Used inside PLAM to re-emphasize informative channels and locations.
    """

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        # Channel attention
        self.ca_avg = nn.AdaptiveAvgPool2d(1)
        self.ca_max = nn.AdaptiveMaxPool2d(1)
        self.ca_fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
        )
        # Spatial attention
        self.sa_conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)

    def forward(self, x: Tensor) -> Tensor:
        """x: (B, C, H, W)"""
        # Channel attention
        B, C, H, W = x.shape
        ca_avg = self.ca_avg(x).flatten(1)
        ca_max = self.ca_max(x).flatten(1)
        ca_w = torch.sigmoid(self.ca_fc(ca_avg) + self.ca_fc(ca_max))
        x = x * ca_w.reshape(B, C, 1, 1)
        # Spatial attention
        sa_in = torch.cat([x.mean(dim=1, keepdim=True), x.max(dim=1, keepdim=True)[0]], dim=1)
        sa_w = torch.sigmoid(self.sa_conv(sa_in))
        x = x * sa_w
        return x
class PLAM(nn.Module):
    """
    Pixel-guided Local Alignment Module (Section 3.3, Eq. 8–10).
    Performs SHALLOW → HIGH level feature enrichment.
    Injects fine-grained structural information from F2' into F4'.
    Strategy: position-preserving. Spatial downsampling of shallow
    features → direct concatenation → CBAM → FFN.
    This preserves spatial coherence: nearby shallow pixels inform
    nearby high-level tokens (LOCAL alignment, not global).
    Inputs:
      F2_prime: (B, H2*W2, C2), shallow memory-refined features
      F4_prime: (B, H4*W4, C4), high-level memory-refined features
    Output:
      F4_double_prime: (B, H4*W4, C4), pixel-enriched high-level features
    """

    def __init__(self, c2: int, c4: int):
        super().__init__()
        # Downsample F2' to match F4' spatial resolution + align channels
        self.downsample = nn.Sequential(
            nn.Conv2d(c2, c4, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c4, c4, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
        self.linear_proj = nn.Linear(c4, c4)
        # Sequential channel + spatial attention on concatenated [F4, F2_tmp]
        self.cbam = CBAM(channels=c4 * 2)
        # FFN to project back to C4 space (Eq. 10)
        self.ffn = nn.Sequential(
            nn.Linear(c4 * 2, c4 * 4),
            nn.GELU(),
            nn.Linear(c4 * 4, c4),
        )
        self.norm = nn.LayerNorm(c4)

    def forward(
        self,
        F2_prime: Tensor,       # (B, H2*W2, C2) shallow refined
        F4_prime: Tensor,       # (B, H4*W4, C4) high-level refined
        hw2: Tuple[int, int],   # (H2, W2)
        hw4: Tuple[int, int],   # (H4, W4)
    ) -> Tensor:
        B, N4, C4 = F4_prime.shape
        H2, W2 = hw2
        H4, W4 = hw4
        # Eq. 8: ConvReLU downsample F2' to match F4' resolution
        F2_2d = F2_prime.transpose(1, 2).reshape(B, -1, H2, W2)
        F2_tmp = self.downsample(F2_2d)                 # (B, C4, H4, W4)
        F2_tmp = F2_tmp.flatten(2).transpose(1, 2)      # (B, N4, C4)
        F2_tmp = self.linear_proj(F2_tmp)
        # Eq. 9: Concatenate and apply CBAM
        F4_cat = torch.cat([F4_prime, F2_tmp], dim=-1)  # (B, N4, 2C4)
        F4_2d = F4_cat.transpose(1, 2).reshape(B, C4 * 2, H4, W4)
        F4_2d = self.cbam(F4_2d)                        # channel + spatial attn
        F4_cat = F4_2d.flatten(2).transpose(1, 2)
        # Eq. 10: FFN back to original C4 space
        F4_out = self.ffn(F4_cat)
        F4_out = self.norm(F4_out + F4_prime)           # residual
        return F4_out
# ─── SECTION 5: SGIM — Semantic-guided Global Integration Module ───────────────
class SGIM(nn.Module):
    """
    Semantic-guided Global Integration Module (Section 3.3, Eq. 11–12).
    Performs HIGH → SHALLOW level feature enrichment.
    Injects global semantic cues from F4' into F2'.
    Strategy: global cross-attention. Q from F2' (pixel queries),
    K,V from linearly projected F4' (semantic context).
    Every pixel in the shallow map can attend to the FULL set of
    semantic tokens (GLOBAL perception, not local).
    Inputs:
      F4_prime: (B, H4*W4, C4), high-level memory-refined features
      F2_prime: (B, H2*W2, C2), shallow memory-refined features
    Output:
      F2_double_prime: (B, H2*W2, C2), semantically enriched shallow features
    """

    def __init__(self, c2: int, c4: int, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = c2 // num_heads
        self.scale = self.head_dim ** -0.5
        # Align F4' channel dim to C2 (Eq. 11)
        self.linear_proj = nn.Linear(c4, c2)
        # Self-attention to strengthen pixel-level inner relations (Eq. 11)
        self.sa_qkv = nn.Linear(c2, c2 * 3)
        self.sa_proj = nn.Linear(c2, c2)
        self.sa_norm = nn.LayerNorm(c2)
        # Global cross-attention: Q from F2', K,V from F4_tmp (Eq. 12)
        self.ca_q = nn.Linear(c2, c2)
        self.ca_kv = nn.Linear(c2, c2 * 2)
        self.ca_proj = nn.Linear(c2, c2)
        self.ca_norm = nn.LayerNorm(c2)
        # FFN to complete fusion
        self.ffn = nn.Sequential(
            nn.Linear(c2, c2 * 4),
            nn.GELU(),
            nn.Linear(c2 * 4, c2),
        )
        self.ffn_norm = nn.LayerNorm(c2)

    def _attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Multi-head scaled dot-product attention on (B, N, C) inputs."""
        B, Nq, _ = q.shape
        q = q.reshape(B, Nq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(B, k.size(1), self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(B, v.size(1), self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, Nq, -1)
        return out

    def forward(
        self,
        F4_prime: Tensor,  # (B, H4*W4, C4) high-level semantic features
        F2_prime: Tensor,  # (B, H2*W2, C2) shallow pixel features
    ) -> Tensor:
        # Eq. 11: Linear project F4' to C2 dimension → F4_tmp
        F4_tmp = self.linear_proj(F4_prime)  # (B, N4, C2)
        # Eq. 11: Self-attention on F2' to strengthen pixel-level relations.
        # Reuses _attn, which applies the same 1/sqrt(head_dim) scaling as
        # F.scaled_dot_product_attention but also runs on PyTorch < 2.0.
        x = self.sa_norm(F2_prime)
        q_sa, k_sa, v_sa = self.sa_qkv(x).chunk(3, dim=-1)
        F2_prime = F2_prime + self.sa_proj(self._attn(q_sa, k_sa, v_sa))
        # Eq. 12: Global cross-attention, Q from F2', K,V from F4_tmp
        x = self.ca_norm(F2_prime)
        q_ca = self.ca_q(x)                                  # (B, N2, C2)
        k_ca, v_ca = self.ca_kv(F4_tmp).chunk(2, dim=-1)     # each (B, N4, C2)
        ca_out = self._attn(q_ca, k_ca, v_ca)                # (B, N2, C2)
        F2_prime = F2_prime + self.ca_proj(ca_out)
        # FFN: project refined shallow features to complete fusion
        F2_out = F2_prime + self.ffn(self.ffn_norm(F2_prime))
        return F2_out
# ─── SECTION 6: Hierarchical Memory Banks ─────────────────────────────────────
class MemoryBank:
    """
    Sliding window memory bank (Section 3.4).
    Stores the most recent T reference features.
    Updated every k frames using first-in-first-out strategy.
    Memory bank R^i_t ∈ R^{T × H_i × W_i × C_i}
    """

    def __init__(self, max_size: int = 5):
        self.max_size = max_size
        # deque(maxlen=...) drops the oldest entry automatically (FIFO)
        self.buffer: deque = deque(maxlen=max_size)

    def is_empty(self) -> bool:
        return len(self.buffer) == 0

    def push(self, feature: Tensor):
        """feature: (B, N, C), reference feature for one frame."""
        self.buffer.append(feature.detach().cpu())

    def get_memory(self, device) -> Optional[Tensor]:
        """Returns (B, T*N, C) concatenated memory or None if empty."""
        if self.is_empty():
            return None
        return torch.cat([f.to(device) for f in self.buffer], dim=1)

    def reset(self):
        self.buffer.clear()
# ─── SECTION 7: Hierarchical Decoder ──────────────────────────────────────────
class FPNDecoder(nn.Module):
    """
    Hierarchical FPN-style decoder (Section 3.1, Eq. 4).
    Takes multi-scale features [F1, F2'', F3, F4''], starts from the
    coarsest level (F4'') and progressively upsamples, fusing with the
    next finer level at each step (top-down in FPN terms) to produce
    the final segmentation mask.
    """

    def __init__(self, channels: List[int], dec_ch: int = 256):
        super().__init__()
        C1, C2, C3, C4 = channels
        # Lateral projections: align all levels to dec_ch
        self.lat1 = nn.Conv2d(C1, dec_ch, 1)
        self.lat2 = nn.Conv2d(C2, dec_ch, 1)
        self.lat3 = nn.Conv2d(C3, dec_ch, 1)
        self.lat4 = nn.Conv2d(C4, dec_ch, 1)
        # Fusion convolutions for each coarse-to-fine step
        self.fuse43 = nn.Sequential(nn.Conv2d(dec_ch * 2, dec_ch, 3, padding=1), nn.ReLU())
        self.fuse32 = nn.Sequential(nn.Conv2d(dec_ch * 2, dec_ch, 3, padding=1), nn.ReLU())
        self.fuse21 = nn.Sequential(nn.Conv2d(dec_ch * 2, dec_ch, 3, padding=1), nn.ReLU())
        # Final output head
        self.head = nn.Sequential(
            nn.Conv2d(dec_ch, dec_ch // 2, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(dec_ch // 2, 1, 1),
        )

    def forward(
        self,
        F1: Tensor, F2: Tensor, F3: Tensor, F4: Tensor,
        sizes: List[Tuple[int, int]],
    ) -> Tensor:
        """
        F1, F2, F3, F4: token sequences (B, N_i, C_i)
        sizes: spatial sizes [(H1,W1), ..., (H4,W4)]
        Returns: M_Pred (B, 1, H, W), sigmoid-normalized mask
        """
        (H1, W1), (H2, W2), (H3, W3), (H4, W4) = sizes
        B = F1.shape[0]

        # Reshape tokens to 2D feature maps
        def to2d(x, H, W):
            C = x.shape[-1]
            return x.transpose(1, 2).reshape(B, C, H, W)

        p1 = self.lat1(to2d(F1, H1, W1))
        p2 = self.lat2(to2d(F2, H2, W2))
        p3 = self.lat3(to2d(F3, H3, W3))
        p4 = self.lat4(to2d(F4, H4, W4))
        # Coarse-to-fine fusion: upsample the coarser map, fuse with the next finer level
        p43 = self.fuse43(torch.cat([
            F.interpolate(p4, size=(H3, W3), mode='bilinear', align_corners=False), p3
        ], dim=1))
        p32 = self.fuse32(torch.cat([
            F.interpolate(p43, size=(H2, W2), mode='bilinear', align_corners=False), p2
        ], dim=1))
        p21 = self.fuse21(torch.cat([
            F.interpolate(p32, size=(H1, W1), mode='bilinear', align_corners=False), p1
        ], dim=1))
        # Predict at 1/4 resolution, then upsample to the original resolution
        out = self.head(p21)
        out = F.interpolate(out, scale_factor=4, mode='bilinear', align_corners=False)
        return torch.sigmoid(out)
# ─── SECTION 8: Reference Encoder for Memory Update ───────────────────────────
class ReferenceEncoder(nn.Module):
    """
    Simple memory encoder that integrates the predicted mask M_Pred
    into the encoded feature before storing in the memory bank.
    Section 3.4, Fig. 3(e).
    Two ConvReLU layers → Linear projection → stored as reference.
    The mask provides reliable segmentation supervision that helps
    future frames attend to object regions in memory.
    """

    def __init__(self, feat_ch: int):
        super().__init__()
        self.mask_embed = nn.Conv2d(1, feat_ch, kernel_size=1)
        self.conv1 = nn.Sequential(nn.Conv2d(feat_ch * 2, feat_ch, 3, padding=1), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(feat_ch, feat_ch, 3, padding=1), nn.ReLU())
        self.linear = nn.Linear(feat_ch, feat_ch)

    def forward(self, feat: Tensor, mask: Tensor, H: int, W: int) -> Tensor:
        """
        feat: (B, N, C), refined feature
        mask: (B, 1, H_orig, W_orig), predicted mask
        Returns: (B, N, C), mask-integrated reference feature
        """
        B, N, C = feat.shape
        feat_2d = feat.transpose(1, 2).reshape(B, C, H, W)
        # Downsample mask to feature resolution and embed
        mask_down = F.interpolate(mask, size=(H, W), mode='bilinear', align_corners=False)
        mask_embed = self.mask_embed(mask_down)  # (B, C, H, W)
        # Fuse feature and mask
        fused = self.conv1(torch.cat([feat_2d, mask_embed], dim=1))
        fused = self.conv2(fused)
        ref = fused.flatten(2).transpose(1, 2)   # (B, N, C)
        ref = self.linear(ref)
        return ref


# ─── SECTION 9: Full HMHI-Net Model ───────────────────────────────────────────
class HMHINet(nn.Module):
    """
    HMHI-Net: Hierarchical Memory with Heterogeneous Interaction Network
    for Unsupervised Video Object Segmentation (Section 3, Algorithm 1).

    Two-level memory architecture:
      - Shallow memory (level 2): preserves pixel-level spatial details
      - High-level memory (level 4): preserves semantic object consistency

    Heterogeneous interaction:
      - PLAM: shallow → high (local alignment, position-preserving)
      - SGIM: high → shallow (global integration, broad semantic context)

    Forward pass for a single frame (called sequentially for each video frame):
      1. Encode image + flow → F1, F2, F3, F4
      2. Refine F2 and F4 using their respective memory banks (Eq. 5–7)
      3. Apply PLAM to enrich F4 with pixel details from F2 (Eq. 8–10)
      4. Apply SGIM to enrich F2 with semantic context from F4 (Eq. 11–12)
      5. Decode [F1, F2'', F3, F4''] → predicted mask
      6. Update both memory banks with refined features + predicted mask
    """

    def __init__(self, cfg: Optional[HMHIConfig] = None):
        super().__init__()
        cfg = cfg or HMHIConfig()
        self.cfg = cfg
        C = cfg.enc_channels

        # Main components
        self.encoder = HierarchicalEncoder(in_ch=3, channels=C)
        self.mem_refine_shallow = MemoryRefinement(C[1], cfg.num_heads)  # level 2
        self.mem_refine_high = MemoryRefinement(C[3], cfg.num_heads)     # level 4
        self.plam = PLAM(c2=C[1], c4=C[3])                               # shallow → high
        self.sgim = SGIM(c2=C[1], c4=C[3], num_heads=cfg.num_heads)      # high → shallow
        self.decoder = FPNDecoder(C, cfg.dec_channels)

        # Reference encoders for memory update
        self.ref_enc_shallow = ReferenceEncoder(C[1])
        self.ref_enc_high = ReferenceEncoder(C[3])

        # Memory banks (managed externally during inference/training sequences)
        self.mem_shallow = MemoryBank(cfg.memory_size)
        self.mem_high = MemoryBank(cfg.memory_size)

    def reset_memory(self):
        """Reset both memory banks (call at start of each video sequence)."""
        self.mem_shallow.reset()
        self.mem_high.reset()

    def forward(
        self,
        image: Tensor,  # (B, 3, H, W)
        flow: Tensor,   # (B, 3, H, W), optical flow rendered as a 3-channel image
        is_first_frame: bool = False,
    ) -> Tensor:
        """
        Process a single video frame.

        If is_first_frame=True:
            Skip memory refinement and interaction (no reference frames yet).
            Decode from the unrefined features, then populate both memory banks.
        Otherwise:
            Refine with memory → heterogeneous interaction → decode → update memory.

        Returns: M_Pred (B, 1, H, W), predicted foreground mask
        """
        device = image.device

        # Step 1: Encode image + flow → 4-level merged features (Eq. 1)
        F1, F2, F3, F4, sizes = self.encoder(image, flow)
        (H1, W1), (H2, W2), (H3, W3), (H4, W4) = sizes

        if is_first_frame:
            # First frame: no memory available, so skip refinement
            # and compute a prediction from the unrefined features
            M_Pred = self.decoder(F1, F2, F3, F4, sizes)
            # Initialize memory banks with first-frame features + predicted mask
            ref_shallow = self.ref_enc_shallow(F2, M_Pred, H2, W2)
            ref_high = self.ref_enc_high(F4, M_Pred, H4, W4)
            self.mem_shallow.push(ref_shallow)
            self.mem_high.push(ref_high)
            return M_Pred

        # Step 2: Memory refinement with stored reference features (Eq. 5–7)
        R_shallow = self.mem_shallow.get_memory(device)
        R_high = self.mem_high.get_memory(device)
        F2_prime = self.mem_refine_shallow(F2, R_shallow)
        F4_prime = self.mem_refine_high(F4, R_high)

        # Step 3 & 4: Heterogeneous interaction (Eq. 3, 8–12)
        F4_double_prime = self.plam(F2_prime, F4_prime, (H2, W2), (H4, W4))  # S2H
        F2_double_prime = self.sgim(F4_prime, F2_prime)                      # H2S

        # Step 5: Decode updated multi-scale features (Eq. 4)
        M_Pred = self.decoder(F1, F2_double_prime, F3, F4_double_prime, sizes)

        # Step 6: Update memory banks with refined features + mask (Eq. 4)
        ref_shallow = self.ref_enc_shallow(F2_double_prime, M_Pred, H2, W2)
        ref_high = self.ref_enc_high(F4_double_prime, M_Pred, H4, W4)
        self.mem_shallow.push(ref_shallow)
        self.mem_high.push(ref_high)
        return M_Pred


# ─── SECTION 10: Loss Functions ───────────────────────────────────────────────
class FocalLoss(nn.Module):
    """
    Binary focal loss to address class imbalance
    between foreground object and background pixels.
    Note: alpha scales every pixel uniformly here; there is no
    class-dependent alpha_t weighting in this simplified version.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        # pred must already be a probability in [0, 1] (post-sigmoid)
        bce = F.binary_cross_entropy(pred, target, reduction='none')
        # p_t: probability the model assigns to the true class of each pixel
        p_t = pred * target + (1 - pred) * (1 - target)
        focal_w = self.alpha * ((1 - p_t) ** self.gamma)
        return (focal_w * bce).mean()


class DiceLoss(nn.Module):
    """
    Soft Dice loss for overlap-based optimization.
    Particularly effective for imbalanced segmentation maps.
    """

    def __init__(self, smooth: float = 1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        # Flatten everything except the batch dimension
        p = pred.flatten(1)
        t = target.flatten(1)
        intersection = (p * t).sum(dim=1)
        dice = (2 * intersection + self.smooth) / (p.sum(dim=1) + t.sum(dim=1) + self.smooth)
        return (1 - dice).mean()


class UVOSLoss(nn.Module):
    """
    Combined training loss: BCE + Focal + Dice (Section 4.1).
    Averaged over all frames in the training sequence.
    """

    def __init__(self, w_bce: float = 1.0, w_focal: float = 1.0, w_dice: float = 1.0):
        super().__init__()
        self.bce = nn.BCELoss()
        self.focal = FocalLoss()
        self.dice = DiceLoss()
        self.w_bce = w_bce
        self.w_focal = w_focal
        self.w_dice = w_dice

    def forward(self, preds: List[Tensor], targets: List[Tensor]) -> Tensor:
        """
        preds: list of (B, 1, H, W) sigmoid predictions, one per frame
        targets: list of (B, 1, H, W) binary ground truth masks
        Returns: scalar loss averaged over all frames
        """
        total_loss = 0.0
        for pred, target in zip(preds, targets):
            target = target.float()
            total_loss += (
                self.w_bce * self.bce(pred, target)
                + self.w_focal * self.focal(pred, target)
                + self.w_dice * self.dice(pred, target)
            )
        return total_loss / max(1, len(preds))


# ─── SECTION 11: Dataset, Training Loop & Smoke Test ──────────────────────────
class VideoSequenceDataset(Dataset):
    """
    Synthetic video sequence dataset for demonstration.
    Replace with YouTube-VOS loader for actual training.

    Real datasets:
      YouTube-VOS: https://youtube-vos.org/dataset/vos/
      DAVIS-16: https://davischallenge.org/davis2016/code.html
      FBMS: https://lmb.informatik.uni-freiburg.de/resources/datasets/
    """

    def __init__(
        self,
        num_videos: int = 50,
        seq_len: int = 5,
        img_size: int = 64,  # use 512 for actual training
    ):
        self.num_videos = num_videos
        self.seq_len = seq_len
        self.img_size = img_size

    def __len__(self):
        return self.num_videos

    def __getitem__(self, idx):
        """Returns a video sequence: images, flows, masks all [T, C, H, W]."""
        T, S = self.seq_len, self.img_size
        images = torch.randn(T, 3, S, S)
        flows = torch.randn(T, 3, S, S)
        masks = (torch.rand(T, 1, S, S) > 0.7).float()
        return images, flows, masks


def train_one_epoch(
    model: HMHINet,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: UVOSLoss,
    device: torch.device,
    epoch: int,
) -> float:
    """
    Train for one epoch over video sequences.
    For each sequence: process frames in temporal order, maintain running memory,
    collect predictions, compute average loss, backprop.
    """
    model.train()
    total_loss = 0.0
    for step, (images, flows, masks) in enumerate(loader):
        images = images.to(device)  # (B, T, 3, H, W)
        flows = flows.to(device)    # (B, T, 3, H, W)
        masks = masks.to(device)    # (B, T, 1, H, W)
        B, T = images.shape[:2]
        optimizer.zero_grad()

        # Reset memory for each new video sequence
        model.reset_memory()
        preds_list = []
        targets_list = []
        for t in range(T):
            img_t = images[:, t]   # (B, 3, H, W)
            flow_t = flows[:, t]   # (B, 3, H, W)
            mask_t = masks[:, t]   # (B, 1, H, W)
            # First frame: no memory refinement (Section 4.1)
            pred_t = model(img_t, flow_t, is_first_frame=(t == 0))
            # Skip first frame loss (no memory context yet)
            if t > 0:
                preds_list.append(pred_t)
                targets_list.append(mask_t)

        # Average the loss over frames 2..T (the final training loss)
        if preds_list:
            loss = criterion(preds_list, targets_list)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()

        if step % 5 == 0:
            print(f" Epoch {epoch} | Step {step}/{len(loader)} | Loss {total_loss/(step+1):.4f}")
    return total_loss / max(1, len(loader))


def run_training(
    cfg: Optional[HMHIConfig] = None,
    epochs: int = 3,
    device_str: str = "cpu",
    use_tiny: bool = True,
) -> HMHINet:
    """
    Full training pipeline.

    Phase 1 (pre-training): YouTube-VOS, AdamW lr=6e-5, 150 epochs
    Phase 2 (fine-tuning): DAVIS-16/FBMS, CosineAnnealingLR lr=1e-4

    use_tiny=True: reduced channel dims for fast smoke test.
    Note: with use_tiny=True (the default), any cfg passed in is replaced
    by the tiny smoke-test configuration.
    Set use_tiny=False + GPU for real training.
    """
    device = torch.device(device_str)
    if use_tiny:
        cfg = HMHIConfig(enc_channels=[32, 64, 96, 128], dec_channels=64, num_heads=4)
    else:
        cfg = cfg or HMHIConfig()

    model = HMHINet(cfg).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params/1e6:.2f} M")

    # Synthetic data only; swap in a real video loader for actual training
    dataset = VideoSequenceDataset(num_videos=20, seq_len=5, img_size=64 if use_tiny else 512)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr_pretrain, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = UVOSLoss()

    print(f"\n{'='*55}")
    print(" HMHI-Net Training")
    print(f" Epochs: {epochs} | Device: {device}")
    print(f" Channels: {cfg.enc_channels}")
    print(f"{'='*55}\n")

    for epoch in range(1, epochs + 1):
        avg_loss = train_one_epoch(model, loader, optimizer, criterion, device, epoch)
        scheduler.step()
        print(f"Epoch {epoch}/{epochs} | Avg Loss: {avg_loss:.4f}\n")

    print("Training complete.")
    return model


# ─── SMOKE TEST ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    print("=" * 60)
    print(" HMHI-Net: Full Architecture Smoke Test")
    print("=" * 60)
    torch.manual_seed(42)

    # ── 1. Build tiny model ───────────────────────────────────────────────────
    print("\n[1/5] Instantiating tiny HMHI-Net...")
    cfg = HMHIConfig(enc_channels=[32, 64, 96, 128], dec_channels=64, num_heads=4)
    model = HMHINet(cfg)
    params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {params/1e6:.3f} M")

    # ── 2. Single frame (first, no memory) ───────────────────────────────────
    print("\n[2/5] First frame forward pass (no memory)...")
    img = torch.randn(1, 3, 64, 64)
    flow = torch.randn(1, 3, 64, 64)
    model.reset_memory()
    pred0 = model(img, flow, is_first_frame=True)
    print(f" Output mask shape: {tuple(pred0.shape)} (expected: (1, 1, 64, 64))")
    assert pred0.shape == (1, 1, 64, 64)

    # ── 3. Subsequent frame (with memory) ────────────────────────────────────
    print("\n[3/5] Subsequent frame forward pass (with memory)...")
    pred1 = model(torch.randn(1, 3, 64, 64), torch.randn(1, 3, 64, 64), is_first_frame=False)
    print(f" Output mask shape: {tuple(pred1.shape)}")
    print(f" Memory bank sizes: shallow={len(model.mem_shallow.buffer)}, high={len(model.mem_high.buffer)}")
    assert pred1.shape == (1, 1, 64, 64)

    # ── 4. Loss functions ─────────────────────────────────────────────────────
    print("\n[4/5] Loss function check...")
    criterion = UVOSLoss()
    preds = [torch.rand(2, 1, 64, 64) for _ in range(4)]
    targets = [(torch.rand(2, 1, 64, 64) > 0.5).float() for _ in range(4)]
    loss_val = criterion(preds, targets)
    print(f" Combined BCE+Focal+Dice loss: {loss_val.item():.4f}")

    # ── 5. Short training run ─────────────────────────────────────────────────
    print("\n[5/5] Short training run (2 epochs, tiny config)...")
    run_training(epochs=2, device_str="cpu", use_tiny=True)

    print("\n" + "=" * 60)
    print("✓ All checks passed. HMHI-Net is ready for use.")
    print("=" * 60)
    print("""
Next steps:
  1. Load pretrained MiT-b1 weights (SegFormer backbone), for example
     from the Hugging Face Hub:
       pip install transformers
       from transformers import SegformerModel
       mit_b1 = SegformerModel.from_pretrained('nvidia/mit-b1')
       # Copy weights into HierarchicalEncoder stages
  2. Prepare YouTube-VOS dataset:
       https://youtube-vos.org/dataset/vos/
       → Pre-train for 150 epochs (lr=6e-5, AdamW)
  3. Fine-tune on DAVIS-16:
       https://davischallenge.org/davis2016/code.html
       → CosineAnnealingLR, lr=1e-4, until convergence
  4. Evaluate with the DAVIS evaluation toolkit:
       pip install git+https://github.com/davisvideochallenge/davis2017-evaluation.git
       → J&F (region similarity + boundary accuracy)
  5. Scale to full 512×512 resolution on 4x RTX 4090 GPUs.
       cfg = HMHIConfig()  # default channels [64, 128, 320, 512]
""")
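To make the combined objective easier to sanity-check, the sketch below recomputes the three terms of the loss by hand on a made-up four-pixel mask, using only the standard library. The pixel values and the helper names (`bce`, `focal`, `soft_dice_loss`) are illustrative assumptions that mirror the educational `FocalLoss` and `DiceLoss` classes in the listing, with equal weights on all three terms; they are not taken from the authors' released code.

```python
import math

def bce(p: float, t: int) -> float:
    """Per-pixel binary cross-entropy for probability p and label t (0 or 1)."""
    return -(t * math.log(p) + (1 - t) * math.log(1 - p))

def focal(p: float, t: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Per-pixel focal term: alpha * (1 - p_t)^gamma * BCE (uniform alpha)."""
    p_t = p * t + (1 - p) * (1 - t)  # probability assigned to the true class
    return alpha * (1 - p_t) ** gamma * bce(p, t)

def soft_dice_loss(preds, targets, smooth: float = 1e-5) -> float:
    """1 - soft Dice coefficient over a flat list of pixels."""
    intersection = sum(p * t for p, t in zip(preds, targets))
    dice = (2 * intersection + smooth) / (sum(preds) + sum(targets) + smooth)
    return 1 - dice

# Hypothetical 4-pixel frame: two foreground pixels, two background pixels.
preds = [0.9, 0.6, 0.2, 0.1]
targets = [1, 1, 0, 0]

n = len(preds)
mean_bce = sum(bce(p, t) for p, t in zip(preds, targets)) / n
mean_focal = sum(focal(p, t) for p, t in zip(preds, targets)) / n
dice_loss = soft_dice_loss(preds, targets)
total = mean_bce + mean_focal + dice_loss  # equal weights on all three terms

print(f"BCE={mean_bce:.4f}  Focal={mean_focal:.4f}  Dice={dice_loss:.4f}  Total={total:.4f}")
```

In this toy case the confidently correct pixels (p_t = 0.9) contribute almost nothing to the focal term, while the uncertain foreground pixel (p_t = 0.6) accounts for most of it. The Dice term ignores per-pixel confidence and responds only to the overall overlap between prediction and mask.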
Code, Paper & Project Page
The official implementation, pretrained checkpoints, and full benchmark evaluation code are available on GitHub. The paper was published at ACM Multimedia 2025.
Zheng, X., He, S., Li, W., Li, X., & Zhang, W. (2025). Shallow Features Matter: Hierarchical Memory with Heterogeneous Interaction for Unsupervised Video Object Segmentation. In Proceedings of the 33rd ACM International Conference on Multimedia (MM ’25). https://doi.org/10.1145/3746027.3755848. arXiv:2507.22465.
This article is an independent editorial analysis of peer-reviewed research. The PyTorch implementation is an educational adaptation of the methods described in the paper. The original authors trained on four NVIDIA RTX 4090 GPUs; refer to the paper and official GitHub repository for exact training configurations, pretrained weights, and full benchmark results.
