SAMM: SAM2 Fine-Tuned for Universal Material Micrograph Segmentation | AI Trend Blend

SAMM: Teaching SAM2 to Read a Microstructure — and Generalise Across All of Materials Science

Researchers at Huazhong Agricultural University and Central South University fine-tuned the Segment Anything Model 2 (SAM2) with full-parameter adaptation, a cross-scale feature fusion decoder, and a hybrid loss function. The result is a single model that achieves up to 98.13% mIoU across 13 material microscopy datasets and 97.41% mIoU on material systems it has never seen.


The microstructure of a material is its biography. A 10 nm increase in the γ′ precipitate size in a nickel superalloy cuts creep life by over 30%. A 50% reduction in ferrite-pearlite lamellar spacing raises steel strength by 40%. Ceramic grain size distributions govern fracture toughness. Additive manufacturing powder morphology predicts component density. All of these relationships demand precise, quantitative segmentation of microscopy images, yet until now a model that segmented one of these materials reliably typically broke down on the next. SAMM is built to be a general-purpose answer.


Why Every Prior Approach Is Too Narrow

The history of deep learning in materials microscopy is a history of specialisation. Mask R-CNN for aluminium precipitates in TEM. U-Net for γ′ phases in SEM. SegNet for ferrite boundaries in optical microscopy. Pix2Pix for additive manufacturing powders. Each model, carefully tuned, works excellently on its target material and imaging condition — and degrades substantially everywhere else.

The fundamental problem is domain shift. Nickel superalloy SEM images and titanium alloy SEM images are both grayscale electron micrographs, but the contrast mechanisms, noise profiles, phase boundary sharpness, and feature scales differ enough that a model trained on one cannot reliably transfer to the other. The closest prior attempts at generality — such as training on large multi-material datasets (MicroNet), or applying the original SAM without fine-tuning (MatSAM) — each exposed a different failure mode.

MicroNet-pretrained encoders improve out-of-distribution performance but still require task-specific fine-tuning heads. MatSAM applies SAM directly, which avoids training entirely but introduces prompt dependency: every inference requires user-provided point or box annotations, inference is slow, and the frozen backbone cannot adapt to the grayscale, low-contrast SEM texture that dominates materials imaging. The paper's zero-shot comparison quantifies the gap: MatSAM-H reaches at most 91.24% mIoU on a powder dataset where SAMM achieves 97.41%, and unmodified SAM2 collapses to 10.34% mIoU on a nanomaterial dataset where SAMM achieves 94.58%.

The Core Problem

Prompt-dependent zero-shot models (SAM, MatSAM) generalise architecturally but fail on material-specific texture. Task-specific models (U-Net, Mask R-CNN variants) achieve high accuracy on their target domain but collapse under domain shift. SAMM resolves this by fine-tuning the SAM2 backbone fully on a diverse 13-dataset materials corpus, eliminating prompt dependency while retaining SAM2’s structural segmentation inductive biases.

What SAMM Changes: Four Coordinated Modifications to SAM2

SAMM is not a replacement for SAM2. It is a careful fine-tuning strategy applied on top of SAM2's architecture, with four specific modifications. Three of them appear as cumulative stages in the paper's ablation study, so each reported gain is measured on top of the previous stage.

Modification 1: Full-Parameter Fine-Tuning

SAM2’s original design assumes that the hierarchical Vision Transformer encoder can be kept frozen and applied zero-shot to new domains. For natural images and medical images this works adequately. For SEM, TEM, and optical metallography it does not. The encoder’s convolutional filters and attention layers are calibrated to RGB colour distributions; SEM images are grayscale with distinct noise characteristics (charging artifacts, shadowing, grain-contrast inversion) that a frozen encoder cannot fully interpret.

SAMM unfreezes all encoder parameters. This is the single most impactful modification: the ablation shows a jump from 73.56% to 84.73% mIoU (Strategy 1 → Strategy 2), an 11.17-point gain purely from allowing the backbone to adapt to material micrograph statistics. Every attention head, every layer normalisation scale, every MLP weight is updated during fine-tuning on the 13-subset dataset.

Modification 2: Hybrid BCE + IoU-Aware Loss

SAM2 uses a cross-entropy-based loss that treats every pixel independently. For material micrographs, this has two problems: it is insensitive to the geometric consistency of predicted regions, and it exacerbates class imbalance in sparse-phase datasets (where the phase of interest might occupy only 5–10% of the image area).

SAMM introduces a composite loss with two terms. The segmentation loss is binary cross-entropy with a numerical stability constant ε:

Eq. 1 — Segmentation Loss $$\mathcal{L}_{\text{seg}} = -\frac{1}{N}\sum_{i=1}^{N}\left[m_i\log\!\left(\sigma(M^{(i)}_{\text{prd}} + \varepsilon)\right) + (1-m_i)\log\!\left(1-\sigma(M^{(i)}_{\text{prd}} + \varepsilon)\right)\right]$$
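Eq. 1 can be sketched in plain Python as follows, treating the index i as running over pixels and M as raw logits (both assumptions). As printed, Eq. 1 adds ε inside the sigmoid; this sketch instead clamps the probability to [ε, 1 − ε] before taking the log, which is the usual way to keep the loss finite. The names `seg_loss` and `sigmoid` are illustrative, not taken from the paper's code.

```python
import math
from typing import Sequence


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def seg_loss(logits: Sequence[float], targets: Sequence[int], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy (sketch of Eq. 1) over flattened pixels.

    Assumption: eps clamps the probability away from 0 and 1 instead of
    being added to the logit as in the printed equation.
    """
    if len(logits) != len(targets) or len(logits) == 0:
        raise ValueError("logits and targets must be non-empty and of equal length")
    total = 0.0
    for z, m in zip(logits, targets):
        p = min(max(sigmoid(z), eps), 1.0 - eps)
        total += m * math.log(p) + (1 - m) * math.log(1.0 - p)
    return -total / len(logits)


# An undecided pixel costs log(2); confident correct pixels cost almost nothing
print(round(seg_loss([0.0], [1]), 4))        # 0.6931
print(seg_loss([8.0, -8.0], [1, 0]) < 0.01)  # True
```

Clamping keeps a confidently wrong pixel (for example a logit of 100 with label 0) at a large but finite cost instead of producing `log(0)`.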

The IoU-aware auxiliary loss regresses the predicted confidence score toward the actual mask-GT intersection-over-union, suppressing fragmented predictions:

Eq. 2 — IoU-Aware Loss $$\mathcal{L}_{\text{iou}} = \frac{1}{N}\sum_{i=1}^{N}\left|S^{(i)}_{\text{prd}} - \text{IoU}\!\left(\mathbb{1}_{\sigma > 0.5}(M^{(i)}_{\text{prd}}),\, M^{(i)}_{\text{gt}}\right)\right|$$

Combined with a weighting factor λ = 0.05 (identified via grid search):

Eq. 3 — Total Loss $$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{seg}} + \lambda\,\mathcal{L}_{\text{iou}}, \quad \lambda = 0.05$$
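Eqs. 2 and 3 can be sketched the same way. This minimal version assumes each predicted mask arrives as a flat list of logits with one predicted confidence score per mask. It binarises at σ > 0.5 (equivalently, logit > 0) and defines the IoU of two empty masks as 1.0, a choice the paper does not specify. `binary_iou`, `iou_aware_loss` and `total_loss` are hypothetical helper names.

```python
from typing import List, Sequence


def binary_iou(pred: Sequence[int], gt: Sequence[int]) -> float:
    """IoU of two binary masks given as flat 0/1 sequences."""
    inter = sum(1 for p, g in zip(pred, gt) if p and g)
    union = sum(1 for p, g in zip(pred, gt) if p or g)
    return 1.0 if union == 0 else inter / union  # assumption: empty vs empty = 1.0


def iou_aware_loss(scores: Sequence[float],
                   mask_logits: Sequence[Sequence[float]],
                   gt_masks: Sequence[Sequence[int]]) -> float:
    """Sketch of Eq. 2: mean |S_prd - IoU(1[sigma > 0.5](M_prd), M_gt)| over N masks."""
    errors: List[float] = []
    for s, logits, gt in zip(scores, mask_logits, gt_masks):
        binarised = [1 if z > 0.0 else 0 for z in logits]  # sigmoid(z) > 0.5  <=>  z > 0
        errors.append(abs(s - binary_iou(binarised, gt)))
    if not errors:
        raise ValueError("need at least one mask")
    return sum(errors) / len(errors)


def total_loss(l_seg: float, l_iou: float, lam: float = 0.05) -> float:
    """Eq. 3: L_total = L_seg + lambda * L_iou, with lambda = 0.05 as in the paper."""
    return l_seg + lam * l_iou


# One mask: predicted [1, 1, 0, 0] vs ground truth [1, 0, 0, 0] gives IoU = 0.5
l_iou = iou_aware_loss([0.9], [[2.0, 2.0, -2.0, -2.0]], [[1, 0, 0, 0]])
print(round(l_iou, 6))                    # 0.4
print(round(total_loss(0.3, l_iou), 6))   # 0.32
```

With λ = 0.05 the IoU term only nudges the total: a 0.4 calibration error adds 0.02 to a segmentation loss of 0.3.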

The ablation registers a 3.22-point mIoU gain from adding this loss on top of Strategy 2 (84.73% → 87.95%).

Modification 3: Normalised Coordinate Prompt Encoder

SAM2’s original prompt encoder maps point coordinates in absolute pixel space, creating an implicit bias toward the image resolution it was trained on. Material micrographs vary widely in resolution — from 500×500 pixels to 2501×2501. SAMM redefines the coordinate mapping to the normalised range [−1, 1], eliminating resolution dependency and improving spatial alignment between prompts and feature maps, particularly for the large-format SEM images in Data 1.
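The normalisation itself is a one-line affine map per axis. This sketch assumes a pixel-centre convention (adding 0.5 before dividing by the image size); the paper does not state which convention SAMM adopts, and `normalise_point` is a hypothetical name.

```python
from typing import Tuple


def normalise_point(x: float, y: float, width: int, height: int) -> Tuple[float, float]:
    """Map pixel coordinates to [-1, 1] independently of image resolution.

    Assumption: pixel-centre convention, so pixel (0, 0) maps just inside -1
    and the last pixel maps just inside +1.
    """
    if width <= 0 or height <= 0:
        raise ValueError("image size must be positive")
    nx = 2.0 * (x + 0.5) / width - 1.0
    ny = 2.0 * (y + 0.5) / height - 1.0
    return nx, ny


# The centre of a 500x500 image and of a 2501x2501 image map to the same point
print(normalise_point(249.5, 249.5, 500, 500))    # (0.0, 0.0)
print(normalise_point(1250, 1250, 2501, 2501))    # (0.0, 0.0)
```

Because the same relative position always lands on the same normalised coordinate, a prompt embedding learned on 500 px images carries over to 2501 px images without rescaling.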

Modification 4: Cross-Scale Feature Fusion in the Decoder

The paper frames the standard SAM2 decoder as working mainly from the encoder's bottleneck embedding. SAMM adds an explicit cross-scale fusion module that combines high-level semantic features from the encoder's deep layers with fine-grained boundary features from its early layers. Channel-wise concatenation and feature pyramid alignment recover spatial detail that would otherwise be lost through downsampling. This matters most for the blurred or overlapping phase boundaries in challenging micrographs such as Data 5 (dense γ′ with high noise) and Data 12 (sub-micron nanomaterial interfaces). Adding this module gives Strategy 4, the final SAMM model, at 89.68% mIoU.
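The core data movement, upsampling coarse features to the fine grid and then concatenating channel-wise, can be sketched with NumPy. This toy version assumes integer scale factors, uses nearest-neighbour upsampling, and mixes the concatenated channels with a fixed 1×1 projection. In a trained model that projection (and any refinement layers) would be learned; all names here are illustrative.

```python
import numpy as np


def upsample_nearest(x: np.ndarray, factor: int) -> np.ndarray:
    """(C, H, W) -> (C, H*factor, W*factor) by repeating pixels."""
    return x.repeat(factor, axis=1).repeat(factor, axis=2)


def cross_scale_fuse(high: np.ndarray, low: np.ndarray, proj: np.ndarray) -> np.ndarray:
    """Fuse coarse semantic features `high` (C_h, h, w) with fine features
    `low` (C_l, H, W): upsample, concatenate channels, then apply a 1x1
    projection `proj` of shape (C_out, C_h + C_l)."""
    _, h, w = high.shape
    _, H, W = low.shape
    if H % h or W % w or H // h != W // w:
        raise ValueError("fine grid must be an integer multiple of the coarse grid")
    up = upsample_nearest(high, H // h)
    stacked = np.concatenate([up, low], axis=0)      # (C_h + C_l, H, W)
    return np.einsum("oc,chw->ohw", proj, stacked)   # 1x1 conv = per-pixel matmul


rng = np.random.default_rng(0)
high = rng.standard_normal((8, 4, 4))    # deep, coarse features
low = rng.standard_normal((2, 16, 16))   # early, fine features
proj = rng.standard_normal((4, 10))      # hypothetical projection weights
fused = cross_scale_fuse(high, low, proj)
print(fused.shape)  # (4, 16, 16)
```

Every output pixel sees both its own fine-grained features and the semantic context of the coarse cell it falls in. That combination is what lets a decoder place a boundary precisely while still knowing which phase it separates.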

Strategy 1: 73.56% mIoU (SAM2 zero-shot, frozen)
Strategy 2: 84.73% mIoU (+ full-parameter unfreezing)
Strategy 3: 87.95% mIoU (+ hybrid loss function)
Strategy 4 (SAMM): 89.68% mIoU (+ cross-scale fusion)
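Because the ablation is cumulative, each stage's contribution is its difference from the previous stage. A short calculation on the reported numbers makes the relative weight of each modification explicit:

```python
# Cumulative ablation results reported in the article (mIoU, %)
strategies = [
    ("Strategy 1: SAM2 zero-shot (frozen)", 73.56),
    ("Strategy 2: + full-parameter unfreezing", 84.73),
    ("Strategy 3: + hybrid loss", 87.95),
    ("Strategy 4: + cross-scale fusion (SAMM)", 89.68),
]

# Gain of each stage over the previous one, in percentage points
gains = [
    (name, round(miou - prev, 2))
    for (name, miou), (_, prev) in zip(strategies[1:], strategies[:-1])
]
total = round(strategies[-1][1] - strategies[0][1], 2)

for name, gain in gains:
    print(f"{name}: +{gain:.2f} points ({gain / total:.0%} of the total gain)")
print(f"Total: +{total:.2f} points")
```

On these figures, unfreezing accounts for roughly 69% of the 16.12-point improvement, the hybrid loss for about 20%, and cross-scale fusion for about 11%.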

The Dataset: 13 Subsets, 3,490 Images, 381,962 Masks

The benchmark the team assembled is arguably as significant as the model itself. It spans the full breadth of quantitative microstructural analysis — from nickel superalloy precipitates to additive manufacturing powders to vanadium pentoxide nanowires — and it is publicly released alongside the paper.

The eight self-collected datasets (Data 1–8) were built using a high-throughput multi-component diffusion multiple (MCDM) strategy. Alloys with varying compositions were assembled, electron-beam welded, hot-isostatically pressed, heat treated at three temperatures (800°C, 900°C, 1000°C), then imaged at 20 nm pixel resolution on a Zeiss Supra 55 field-emission SEM. Data 1 contains the geometrically close-packed (GCP) η phase and the topologically close-packed (TCP) σ phase, drawn from 297 BSE images. Data 2–4 contain γ′ phase precipitates from secondary electron images at the three annealing temperatures, capturing the full range of coarsening and morphological evolution. Data 5–6 extend this to a novel Ni-Co-based superalloy and a wrought superalloy with a challenging tri-modal γ′ size distribution. Data 7–8 cover additive manufacturing powders: IN718 produced by both the Plasma Rotating Electrode Process and Argon Gas Atomisation (30,743 individually annotated particles), and 65 rare and precious metal powders including platinum and PtRh30 alloys.

The five publicly sourced datasets (Data 9–13) bring in V₂O₅ nanowires, Ti-6Al-4V α phases, and multi-system benchmarks from four independent research groups. Together, these 13 subsets provide exposure to seven distinct material classes, five imaging modalities, and resolution scales from 500 px to 2501 px — a diversity no prior materials segmentation benchmark had achieved.


Results Across 13 Datasets

Table 4 in the paper compares SAMM against seven baselines from the MMSegmentation framework: UNet, DeepLabV3+, PSPNet, SegFormer, Mask2Former, KNet, and FastSCNN. SAMM posts the highest mIoU on 8 of the 13 datasets and, by the paper's count, ranks first or second on 11 of 13. The story is not one of uniform dominance; it is one of dominance in the right places.

| Dataset | Material | SAMM mIoU | Best baseline mIoU | Gain (points) |
|---|---|---|---|---|
| Data 1 | Superalloy η/σ (complex topology) | 94.56% | UNet 87.27% | +7.29 |
| Data 2 | Superalloy γ′, low contrast | 84.64% | UNet 83.45% | +1.19 |
| Data 3 | Superalloy γ′, varying illumination | 89.68% | KNet 91.60% | −1.92 |
| Data 4 | Superalloy γ′, dense connected | 89.82% | KNet 86.94% | +2.88 |
| Data 5 | Superalloy γ′, high noise | 76.63% | FastSCNN 67.83% | +8.80 |
| Data 6 | Ni wrought superalloy, tri-modal γ′ | 94.30% | Mask2Former 94.07% | +0.23 |
| Data 7 | IN718 AM powder (SEM) | 98.13% | KNet 97.86% | +0.27 |
| Data 8 | Rare metal powders (Pt, PtRh30) | 97.41% | UNet 97.12% | +0.29 |
| Data 9 | V₂O₅ nanowires | 98.71% | KNet 99.26% | −0.55 |
| Data 10 | Ti-6Al-4V α phase | 86.10% | Mask2Former 86.16% | −0.06 |
| Data 11 | Multi-alloy γ′ benchmark | 94.37% | KNet 93.58% | +0.79 |
| Data 12 | δ/o nanocrystalline phases | 95.21% | Mask2Former 95.51% | −0.30 |
| Data 13 | ε phase, carbon steel | 71.27% | Mask2Former 71.65% | −0.38 |
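The Gain column is the SAMM mIoU minus the best-baseline mIoU, in percentage points. Recomputing it from the two mIoU columns also gives the number of datasets on which SAMM leads outright and where its largest deficits lie:

```python
# (dataset, SAMM mIoU, best-baseline mIoU) transcribed from the table above
rows = [
    ("Data 1", 94.56, 87.27), ("Data 2", 84.64, 83.45), ("Data 3", 89.68, 91.60),
    ("Data 4", 89.82, 86.94), ("Data 5", 76.63, 67.83), ("Data 6", 94.30, 94.07),
    ("Data 7", 98.13, 97.86), ("Data 8", 97.41, 97.12), ("Data 9", 98.71, 99.26),
    ("Data 10", 86.10, 86.16), ("Data 11", 94.37, 93.58), ("Data 12", 95.21, 95.51),
    ("Data 13", 71.27, 71.65),
]

# Gain in percentage points, rounded to the table's two decimals
gains = {name: round(samm - base, 2) for name, samm, base in rows}
leads = [name for name, g in gains.items() if g > 0]
trails = sorted((g, name) for name, g in gains.items() if g < 0)  # most negative first

print(f"SAMM leads on {len(leads)} of {len(rows)} datasets")
print("Largest deficits:", trails[:2])
print("Mean gain:", round(sum(gains.values()) / len(gains), 2))
```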

The two most instructive numbers are Data 1 (+7.29 points over UNet) and Data 5 (+8.80 points over FastSCNN). Data 1 contains intertwined η and σ phases with genuinely complex topological connectivity, the kind of structure where fragmented predictions and false contours accumulate. Data 5 is the dense, high-noise superalloy dataset on which every baseline struggles. These are precisely the scenarios the cross-scale fusion and hybrid loss were designed for, and the large margins there are consistent with both components working as intended.

SAMM trails the best baseline on five datasets. The two largest gaps are on Data 3, where KNet leads at 91.60% vs SAMM's 89.68%, and Data 9 (V₂O₅ nanowires), where KNet leads at 99.26% vs SAMM's 98.71%; on Data 10, 12, and 13 the deficit is under 0.4 points. For Data 3 and Data 9, qualitative inspection in the paper's figures shows SAMM producing cleaner boundaries and fewer over-segmentation artifacts despite the marginal mIoU deficit. The authors note this directly.

The boundary F1 analysis (Table 5 in the paper, tolerance = 2 pixels) adds a second dimension to the comparison. SAMM achieves Boundary F1 scores exceeding 0.96 on Data 2, Data 7, Data 9, and Data 11 — datasets with fine-grained continuous boundaries where edge fidelity matters most for downstream morphological measurement.
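Boundary F1 scores predicted contours against ground-truth contours rather than regions. The NumPy sketch below assumes that boundary pixels are foreground pixels with at least one background 4-neighbour, and that the 2-pixel tolerance is a square (Chebyshev) neighbourhood. The paper's exact boundary extraction and distance metric may differ.

```python
import numpy as np


def boundary(mask: np.ndarray) -> np.ndarray:
    """Foreground pixels that have at least one background 4-neighbour."""
    m = mask.astype(bool)
    p = np.pad(m, 1, mode="constant", constant_values=False)
    interior = p[:-2, 1:-1] & p[2:, 1:-1] & p[1:-1, :-2] & p[1:-1, 2:]
    return m & ~interior


def dilate(b: np.ndarray, r: int) -> np.ndarray:
    """Binary dilation with a (2r+1) x (2r+1) square structuring element."""
    h, w = b.shape
    p = np.pad(b, r, mode="constant", constant_values=False)
    out = np.zeros_like(b, dtype=bool)
    for dy in range(2 * r + 1):
        for dx in range(2 * r + 1):
            out |= p[dy:dy + h, dx:dx + w]
    return out


def boundary_f1(pred: np.ndarray, gt: np.ndarray, tol: int = 2) -> float:
    bp, bg = boundary(pred), boundary(gt)
    n_p, n_g = int(bp.sum()), int(bg.sum())
    if n_p == 0 and n_g == 0:
        return 1.0
    if n_p == 0 or n_g == 0:
        return 0.0
    precision = (bp & dilate(bg, tol)).sum() / n_p  # predicted edges near a true edge
    recall = (bg & dilate(bp, tol)).sum() / n_g     # true edges matched by a predicted edge
    if precision + recall == 0:
        return 0.0
    return float(2 * precision * recall / (precision + recall))


gt = np.zeros((20, 20), dtype=bool)
gt[5:15, 5:15] = True
shifted = np.roll(gt, 1, axis=0)  # same square, one pixel lower
print(boundary_f1(shifted, gt))                         # 1.0 (within tolerance)
print(boundary_f1(np.roll(gt, 4, axis=1), gt) < 1.0)    # True
```

A one-pixel shift leaves mIoU below 1 but costs nothing here, which is why the metric suits measurements that depend on edge position rather than on every interior pixel.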


Zero-Shot Generalisation: The More Demanding Test

The in-distribution results establish SAMM as a competitive segmentation model. The zero-shot experiments establish something more important: that it behaves like a genuine foundation model for materials microscopy, not just a well-trained specialist.

The experimental design is rigorous. Models are trained exclusively on Data 1–7, then deployed without any fine-tuning on Data 8–13. The six held-out datasets are divided into a homologous group (Data 8, 9, 11 — morphologically similar to training data) and a heterologous group (Data 10, 12, 13 — morphologically distinct from everything seen during training). The heterologous group represents genuine out-of-distribution stress testing.

| Unseen dataset | SAMM | SAM2 | MatSAM-H | Best baseline |
|---|---|---|---|---|
| Data 8 (rare metal powders) | 97.41% | 87.56% | 91.24% | KNet 96.87% |
| Data 9 (V₂O₅ nanowires) | 94.71% | 80.21% | 79.03% | KNet 91.27% |
| Data 10 (Ti-6Al-4V, extreme noise) | 50.12% | 36.94% | 50.06% | Mask2Former 45.57% |
| Data 11 (multi-alloy γ′) | 94.37% | 87.56% | 83.15% | Mask2Former 94.26% |
| Data 12 (nanocrystalline δ/o) | 94.58% | 10.34% | 84.61% | MatSAM-L 88.49% |
| Data 13 (carbon steel ε phase) | 57.06% | 45.05% | 52.34% | MatSAM-H 52.34% |
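Averaging the zero-shot table over the two held-out groups makes the homologous/heterologous gap concrete. The grouping follows the experimental design described above, and the figures are transcribed from the table; `group_means` is an illustrative helper.

```python
# Zero-shot mIoU (%) from the table above: dataset -> (SAMM, SAM2, MatSAM-H)
zero_shot = {
    8: (97.41, 87.56, 91.24),
    9: (94.71, 80.21, 79.03),
    10: (50.12, 36.94, 50.06),
    11: (94.37, 87.56, 83.15),
    12: (94.58, 10.34, 84.61),
    13: (57.06, 45.05, 52.34),
}
groups = {"homologous": [8, 9, 11], "heterologous": [10, 12, 13]}
models = ["SAMM", "SAM2", "MatSAM-H"]


def group_means(ids):
    """Mean mIoU per model over the given dataset ids."""
    return [sum(zero_shot[i][k] for i in ids) / len(ids) for k in range(len(models))]


for group, ids in groups.items():
    summary = ", ".join(f"{m} {v:.2f}%" for m, v in zip(models, group_means(ids)))
    print(f"{group}: {summary}")
```

On these figures SAMM averages about 95.5% on the homologous group and about 67.3% on the heterologous group, compared with roughly 85.1% and 30.8% for unmodified SAM2.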

The Data 12 result is the most striking in the table. Unmodified SAM2 scores 10.34% mIoU, a near-total failure, while SAMM achieves 94.58%. The nanocrystalline δ/o phase interfaces are sub-micron, cross-scale, and visually unlike anything in the training set. MatSAM's best variant (MatSAM-L) reaches 88.49%. The roughly 6-point gap between SAMM and this next-best, prompt-dependent model is consistent with the benefit of full-parameter fine-tuning on high-diversity material data: the backbone appears to have learned material-specific texture statistics that prompt-based inference alone cannot compensate for.

The heterologous stress tests also reveal the honest limits of the approach. Data 10 (Ti-6Al-4V under extreme noise) degrades all models significantly — SAMM holds at 50.12%, still best, but the performance reflects how much this dataset’s imaging conditions differ from the training distribution. Data 13 (carbon steel with oxidised corrosion surfaces) sits at 57.06%, where Mask2Former and KNet essentially fail completely. The authors flag both of these limitations directly in the discussion section — the model is not yet universal for highly heterogeneous structures or severely noisy images — which is the appropriate scientific posture.

“SAMM exhibits exceptional zero-shot generalization, achieving up to 97.41% mIoU on datasets entirely unseen during training. This work not only presents a robust framework for universal microstructure segmentation but also provides a comprehensive, publicly available dataset to foster further research in this domain.” — Tu, Wang, Li, Tan, Huang & Liu, Advanced Powder Materials 5 (2026) 100404

Morphological Parameter Validation: Closing the Loop to Materials Science

Segmentation accuracy (mIoU) is an engineering metric. The reason materials scientists care about segmentation is what it enables downstream: quantifying precipitate size, morphology, and distribution to predict or explain material properties. The paper validates that SAMM’s segmentation quality actually translates into accurate morphological measurements.

Four parameters were extracted from predicted vs. ground-truth masks: area ratio (AreaRatio), equivalent diameter (EquivDiameter), eccentricity (Eccentricity), and bounding-box area ratio (AreaRatioToRect). Linear regression between predicted and ground-truth values across all cross-dataset predictions yields remarkable statistics: all R² values exceed 0.999, all Pearson correlations exceed 0.999, and all mean absolute percentage errors are below 0.05%. Morphological quantification errors are controlled at sub-pixel level.
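The four descriptors can be computed from a single-region binary mask. This NumPy sketch uses common definitions that are assumptions here: AreaRatio as the region's share of the image, EquivDiameter as the diameter of a circle of equal area, AreaRatioToRect as area over axis-aligned bounding-box area, and Eccentricity from the eigenvalues of the pixel-coordinate covariance (the ellipse with the same second moments). The paper may define some of these differently, for example per particle rather than per image.

```python
import math
from typing import Dict

import numpy as np


def morphology(mask: np.ndarray) -> Dict[str, float]:
    """Morphological descriptors of the foreground of a binary mask (sketch)."""
    m = np.asarray(mask, dtype=bool)
    ys, xs = np.nonzero(m)
    area = float(ys.size)
    if area == 0:
        raise ValueError("mask has no foreground pixels")
    bbox_area = float((ys.max() - ys.min() + 1) * (xs.max() - xs.min() + 1))
    # Covariance of pixel coordinates = second central moments / area
    cov = np.cov(np.stack([xs, ys]).astype(float), bias=True)
    small, large = np.linalg.eigvalsh(cov)  # eigenvalues in ascending order
    ecc = math.sqrt(max(0.0, 1.0 - small / large)) if large > 0 else 0.0
    return {
        "AreaRatio": area / m.size,
        "EquivDiameter": math.sqrt(4.0 * area / math.pi),
        "AreaRatioToRect": area / bbox_area,
        "Eccentricity": ecc,
    }


img = np.zeros((20, 20), dtype=bool)
img[5:15, 5:15] = True          # 10 x 10 square
print({k: round(v, 3) for k, v in morphology(img).items()})

rod = np.zeros((20, 20), dtype=bool)
rod[9:11, 5:15] = True          # 2 x 10 bar: elongated, so eccentricity near 1
print(round(morphology(rod)["Eccentricity"], 3))
```

Running the same function on the predicted and the ground-truth mask of each image yields the paired values that the regression statistics compare.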

This closes the loop. A model that achieves high mIoU but produces morphologically biased measurements — for example, systematically underestimating precipitate area due to boundary erosion — would be scientifically useless even at 99% pixel accuracy. SAMM’s R² > 0.999 on all four morphological indicators confirms that its segmentation fidelity is sufficient for quantitative PSP-relationship work, which is the actual application.
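A small synthetic example shows why an error metric such as MAPE has to accompany the correlation-based ones. This sketch assumes R² is the coefficient of determination of a least-squares fit of predicted on ground-truth values, so it equals the squared Pearson correlation. Under that assumption, a prediction that erodes every precipitate by exactly 5% still scores perfectly on R² and Pearson; only MAPE exposes the bias.

```python
import math
from typing import Sequence


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    return sxy / math.sqrt(sxx * syy)


def mape(pred: Sequence[float], gt: Sequence[float]) -> float:
    """Mean absolute percentage error, in percent."""
    return 100.0 * sum(abs(p - g) / abs(g) for p, g in zip(pred, gt)) / len(gt)


gt_area = [100.0, 200.0, 300.0, 400.0]   # synthetic ground-truth areas (px)
eroded = [95.0, 190.0, 285.0, 380.0]     # every prediction 5% too small

r = pearson(gt_area, eroded)
print(f"Pearson r = {r:.4f}, R^2 = {r * r:.4f}, MAPE = {mape(eroded, gt_area):.2f}%")
```

A perfect correlation only says the predictions are a linear function of the truth. The sub-0.05% MAPE reported in the paper is what rules out a systematic scale error of this kind.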


What SAMM Enables That Nothing Before Could

The practical implication of a model that generalises across material systems and imaging conditions is a different kind of research workflow. Traditional microstructural characterisation requires: select a model architecture, collect training data for the specific material of interest, label it, train, validate, iterate. For a materials scientist studying a new alloy system, this is months of setup before any science happens.

SAMM’s zero-shot performance suggests a different workflow is now viable: deploy SAMM directly on new SEM images, accept the 94–97% mIoU performance on morphologically familiar material systems without any retraining, and reserve fine-tuning effort for genuinely novel imaging conditions. The published dataset and model provide a starting point that no individual research group has had access to before.

The limitations identified by the authors are the right ones to take seriously. Multi-phase, multi-grain systems — where multiple phase types coexist and overlap — remain underserved. The model was primarily optimised for single-phase or simple-grain scenarios, and has not been extended to systems where phase identification itself is entangled with segmentation. Highly heterogeneous structures with extreme noise or boundary blur still show performance below professionally customised models. These are tractable targets for future work given the architecture and dataset infrastructure now in place.

The Transferability Case

SAMM’s architecture and training strategy apply to any domain where: images are grayscale or single-modality, features have multi-scale structure (fine boundary details + coarse semantic regions), and a diverse annotated corpus can be assembled. Industrial quality control imaging, geological thin-section analysis, and semiconductor defect detection all share these characteristics. The four-modification fine-tuning recipe — full-parameter unfreezing, hybrid loss, normalised prompts, cross-scale fusion — is the transferable template.

Complete End-to-End SAMM Implementation (PyTorch)

The implementation below is a complete 1,377-line PyTorch implementation of the SAMM framework, structured in 12 sections mirroring the paper. It covers the full pipeline: hierarchical ViT image encoder with full-parameter fine-tuning (Section 4.2.1a), memory encoder with self- and cross-attention (4.2.1b), normalised coordinate prompt encoder (4.2.1c), cross-scale feature fusion mask decoder (4.2.1d), hybrid BCE + IoU-aware loss (Equations 1–3), mIoU and Boundary F1 metrics, 13-dataset helpers, and a training loop implementing the strategy from Section 4.2.3. The smoke test validates all forward passes without real data.

# ==============================================================================
# SAMM: A General-Purpose Segmentation Model for Material Micrographs
# Paper: Advanced Powder Materials 5 (2026) 100404
# DOI: https://doi.org/10.1016/j.apmate.2026.100404
# Authors: Jiahao Tu, Zi Wang, Weifu Li, Liming Tan, Lan Huang, Feng Liu
# Institutions: Huazhong Agricultural University / Central South University
# ==============================================================================
# Complete end-to-end PyTorch implementation of the SAMM framework.
# Sections:
#   1.  Imports & Configuration
#   2.  Cross-Scale Feature Fusion Module
#   3.  Image Encoder (Hierarchical ViT, full-parameter fine-tuning)
#   4.  Memory Encoder Module
#   5.  Prompt Embedding Module (sparse + dense)
#   6.  Mask Decoder with Multi-Resolution Fusion
#   7.  Full SAMM Model
#   8.  Hybrid Loss Function (BCE + IoU-aware)
#   9.  Evaluation Metrics (mIoU, Boundary F1)
#  10.  Dataset Helpers (13 material microscopy subsets)
#  11.  Training Loop & Validation
#  12.  Smoke Test
# ==============================================================================

from __future__ import annotations

import math
import warnings
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

warnings.filterwarnings("ignore")


# ─── SECTION 1: Configuration ─────────────────────────────────────────────────

class SAMMConfig:
    """
    Hyper-parameter configuration for SAMM.

    Attributes
    ----------
    img_size          : input image size (H = W)
    in_channels       : number of image channels (1 for SEM grayscale, 3 for RGB)
    patch_size        : ViT patch size
    embed_dim         : base embedding dimension for the image encoder
    encoder_depth     : number of transformer blocks in the encoder
    encoder_heads     : attention heads per encoder block
    mlp_ratio         : MLP expansion ratio inside transformer blocks
    memory_depth      : number of transformer blocks in the memory encoder
    decoder_dim       : channel dimension for the mask decoder
    num_mask_tokens   : number of mask output tokens (SAM2 uses 4)
    prompt_embed_dim  : dimension for point/box prompt embeddings
    lambda_iou        : weight of IoU-aware loss (Eq. 3 in paper, λ=0.05)
    lr                : AdamW learning rate (paper: 1e-5)
    weight_decay      : AdamW weight decay (paper: 4e-5)
    """
    img_size: int = 512
    in_channels: int = 1          # SEM images are typically grayscale
    patch_size: int = 16
    embed_dim: int = 768
    encoder_depth: int = 12
    encoder_heads: int = 12
    mlp_ratio: float = 4.0
    memory_depth: int = 4
    decoder_dim: int = 256
    num_mask_tokens: int = 4
    prompt_embed_dim: int = 256
    lambda_iou: float = 0.05      # Eq. 3: optimal from grid search in paper
    lr: float = 1e-5
    weight_decay: float = 4e-5

    def __init__(self, **kwargs):
        # Override defaults; reject misspelled keys instead of silently adding them
        for k, v in kwargs.items():
            if not hasattr(self, k):
                raise AttributeError(f"Unknown SAMMConfig field: {k}")
            setattr(self, k, v)


# ─── SECTION 2: Cross-Scale Feature Fusion Module ────────────────────────────

class CrossScaleFusion(nn.Module):
    """
    Cross-Scale Feature Fusion module introduced in SAMM's mask decoder
    (Section 4.2.1(d) of the paper).

    Aligns high-level semantic features from the encoder bottleneck with
    fine-grained spatial details from earlier encoder stages. The paper
    describes channel-wise concatenation with feature pyramid alignment; this
    implementation simplifies the fusion to projection plus element-wise
    addition. It is SAMM's main change to the decoder and targets the
    irregular or fuzzy phase boundaries found in material micrographs.

    Architecture:
      - Receives feature maps at two scales: coarse (C_h channels) and fine (C_l channels)
      - Projects both to a shared `out_dim` via 1×1 convolutions
      - Upsamples coarse features to match fine resolution
      - Fuses via element-wise addition followed by a 3×3 refinement conv

    Parameters
    ----------
    high_dim  : channel count of the high-level (coarse, semantically rich) features
    low_dim   : channel count of the low-level (fine, spatially detailed) features
    out_dim   : unified output channel dimension after fusion
    """

    def __init__(self, high_dim: int, low_dim: int, out_dim: int):
        super().__init__()
        self.proj_high = nn.Conv2d(high_dim, out_dim, kernel_size=1, bias=False)
        self.proj_low  = nn.Conv2d(low_dim,  out_dim, kernel_size=1, bias=False)
        self.refine    = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.GELU(),
        )
        self.norm_high = nn.BatchNorm2d(out_dim)
        self.norm_low  = nn.BatchNorm2d(out_dim)

    def forward(self, high_feat: Tensor, low_feat: Tensor) -> Tensor:
        """
        Parameters
        ----------
        high_feat : (B, C_h, H_h, W_h)  — semantically rich, spatially coarse
        low_feat  : (B, C_l, H_l, W_l)  — fine-grained, spatially detailed

        Returns
        -------
        fused : (B, out_dim, H_l, W_l)
        """
        high = self.norm_high(self.proj_high(high_feat))
        low = self.norm_low(self.proj_low(low_feat))

        # Upsample high-level features to match low-level spatial dimensions
        high_up = F.interpolate(high, size=low.shape[-2:], mode='bilinear', align_corners=False)
        fused = self.refine(high_up + low)
        return fused


# ─── SECTION 3: Image Encoder ─────────────────────────────────────────────────

class PatchEmbed(nn.Module):
    """Standard ViT patch embedding: splits image into patches, projects to embed_dim."""

    def __init__(self, img_size: int, patch_size: int, in_channels: int, embed_dim: int):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: Tensor) -> Tuple[Tensor, int, int]:
        x = self.proj(x)                           # (B, D, H/P, W/P)
        B, D, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)           # (B, N, D)
        x = self.norm(x)
        return x, H, W


class TransformerBlock(nn.Module):
    """
    Standard ViT transformer block: LayerNorm → MHSA → residual → LayerNorm → MLP → residual.
    All parameters are kept unfrozen in SAMM (full-parameter fine-tuning, Section 4.2.1(a)).
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0, drop: float = 0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn  = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp   = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        )

    def forward(self, x: Tensor) -> Tensor:
        y = self.norm1(x)
        attn_out, _ = self.attn(y, y, y)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


class SAMMImageEncoder(nn.Module):
    """
    Image encoder for SAMM (Section 4.2.1(a)), a simplified stand-in for
    SAM2's hierarchical Vision Transformer.

    Unlike the frozen SAM2 backbone, SAMM fully unfreezes all encoder parameters
    to adapt to grayscale SEM/TEM texture patterns and noise characteristics.

    The blocks are split into three sequential stages. All stages share the
    1/P patch resolution, so the three feature sets differ in depth, not size:
      - Stage 1: patch embeddings + first third of transformer blocks → scale 1/16
      - Stage 2: middle blocks → scale 1/16 (same resolution, deeper features)
      - Stage 3: final blocks + dimensionality reduction → (H/P)×(W/P)×(embed_dim/2) embeddings

    These three sets of features feed into the cross-scale fusion module in the decoder.

    Parameters
    ----------
    img_size   : input image resolution (assumes square)
    patch_size : ViT patch size (default 16 → 1/16 resolution)
    in_channels: input image channels
    embed_dim  : transformer embedding dimension
    depth      : total number of transformer blocks
    num_heads  : multi-head attention heads
    mlp_ratio  : MLP expansion ratio
    """

    def __init__(
        self,
        img_size: int = 512,
        patch_size: int = 16,
        in_channels: int = 1,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
        self.pos_embed   = nn.Parameter(
            torch.zeros(1, self.patch_embed.num_patches, embed_dim)
        )
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Split blocks into three stages for multi-scale feature extraction
        third = depth // 3
        self.blocks_s1 = nn.Sequential(*[TransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(third)])
        self.blocks_s2 = nn.Sequential(*[TransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(third)])
        self.blocks_s3 = nn.Sequential(*[TransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth - 2 * third)])

        # Dimensionality reduction: (B, N, D) → (B, N, D/2), reshaped later to (B, D/2, H/P, W/P)
        self.neck = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
        )
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.img_size = img_size

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """
        Parameters
        ----------
        x : (B, in_channels, H, W)

        Returns
        -------
        dict with keys 's1', 's2', 's3', 'image_embedding'
          's1'/'s2'/'s3' : (B, embed_dim, H/P, W/P), features after each stage
          'image_embedding': (B, embed_dim//2, H/P, W/P), final embedding for decoder
        """
        B = x.shape[0]
        tokens, fH, fW = self.patch_embed(x)    # (B, N, D)
        tokens = tokens + self.pos_embed

        def to_2d(t: Tensor) -> Tensor:
            return t.transpose(1, 2).reshape(B, self.embed_dim, fH, fW)

        # Stages run sequentially; each block already carries its own residual connections
        t1 = self.blocks_s1(tokens)
        t2 = self.blocks_s2(t1)
        t3 = self.blocks_s3(t2)
        s1, s2, s3 = to_2d(t1), to_2d(t2), to_2d(t3)

        # Final neck on the stage-3 output: produce image embedding
        img_emb_tokens = self.neck(t3)    # (B, N, D/2)
        img_emb = img_emb_tokens.transpose(1, 2).reshape(B, self.embed_dim // 2, fH, fW)

        return {'s1': s1, 's2': s2, 's3': s3, 'image_embedding': img_emb}


# ─── SECTION 4: Memory Encoder Module ────────────────────────────────────────

class MemoryEncoder(nn.Module):
    """
    Memory encoder module (Section 4.2.1(b)).

    Conditions current frame representations on object memory from a memory bank,
    using self-attention (intra-frame contextual reasoning) and cross-attention
    (memory-to-current alignment). This module supports temporal consistency in
    multi-frame or multi-patch microstructure analysis tasks.

    In practice for single-image material micrographs, the memory bank is
    initialised as zeros and the module acts as an additional self-attention
    refinement stage on the image embedding.

    Parameters
    ----------
    embed_dim   : channel dimension of image embedding
    memory_dim  : channel dimension of memory tokens
    depth       : number of self+cross attention block pairs
    num_heads   : attention heads
    """

    def __init__(
        self,
        embed_dim: int = 384,
        memory_dim: int = 256,
        depth: int = 4,
        num_heads: int = 8,
    ):
        super().__init__()
        self.proj_in  = nn.Linear(embed_dim, memory_dim)
        self.self_attn_blocks  = nn.ModuleList([
            nn.MultiheadAttention(memory_dim, num_heads, batch_first=True)
            for _ in range(depth)
        ])
        self.cross_attn_blocks = nn.ModuleList([
            nn.MultiheadAttention(memory_dim, num_heads, batch_first=True)
            for _ in range(depth)
        ])
        self.norms_sa  = nn.ModuleList([nn.LayerNorm(memory_dim) for _ in range(depth)])
        self.norms_ca  = nn.ModuleList([nn.LayerNorm(memory_dim) for _ in range(depth)])
        self.proj_out  = nn.Linear(memory_dim, embed_dim)

    def forward(self, image_embedding: Tensor, memory: Optional[Tensor] = None) -> Tensor:
        """
        Parameters
        ----------
        image_embedding : (B, D, H, W)
        memory          : (B, M, memory_dim) optional; zeros if None

        Returns
        -------
        refined_embedding : (B, D, H, W)
        """
        B, D, H, W = image_embedding.shape
        x = image_embedding.flatten(2).transpose(1, 2)  # (B, N, D)
        x = self.proj_in(x)                             # (B, N, memory_dim)

        if memory is None:
            memory = torch.zeros(B, 1, x.shape[-1], device=x.device, dtype=x.dtype)

        for sa, ca, n_sa, n_ca in zip(
            self.self_attn_blocks, self.cross_attn_blocks, self.norms_sa, self.norms_ca
        ):
            xn = n_sa(x)
            sa_out, _ = sa(xn, xn, xn)
            x = x + sa_out

            xn = n_ca(x)
            ca_out, _ = ca(xn, memory, memory)
            x = x + ca_out

        x = self.proj_out(x)                            # (B, N, D)
        refined = x.transpose(1, 2).reshape(B, D, H, W)
        return refined
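

# Illustrative sketch (hypothetical helpers, not part of SAMM or SAM2): why an
# all-zero, single-token memory bank makes the cross-attention step above
# input-independent. With a single key, softmax assigns it weight 1.0 whatever
# the query is, so every image token receives the same value vector. Because
# nn.MultiheadAttention applies an in-projection with bias, that value is the
# bias b_v rather than zero. Pure Python, one head, toy numbers only.
import math


def _single_head_attention(query, keys, values):
    """Scaled dot-product attention of one query over parallel lists of keys/values."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # shift by the max for stability
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]


def _zero_memory_cross_attention(query, b_k, b_v):
    """Cross-attend to a single all-zero memory token (one head)."""
    # Projecting the zero token leaves only the biases: W_k @ 0 + b_k = b_k,
    # and likewise W_v @ 0 + b_v = b_v for the value projection.
    return _single_head_attention(query, keys=[list(b_k)], values=[list(b_v)])

# Usage: two very different queries receive the identical output [0.5, -0.4],
# so the cross-attention contributes only a constant offset in this setting.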


# ─── SECTION 5: Prompt Embedding Module ──────────────────────────────────────

class PromptEncoder(nn.Module):
    """
    Prompt embedding module (Section 4.2.1(c)).

    Supports sparse (point and box) and dense (mask) prompts.
    Coordinates are normalised to [-1, 1] (eliminating input-size bias),
    then projected via learned embeddings. Dense mask prompts are processed
    through gated convolutions that align them with image feature maps.

    Parameters
    ----------
    embed_dim    : output prompt embedding dimension
    img_size     : input image size (for coordinate normalisation)
    """

    def __init__(self, embed_dim: int = 256, img_size: int = 512):
        super().__init__()
        self.embed_dim = embed_dim
        self.img_size  = img_size

        # Point embeddings: foreground and background point types
        self.fg_embed = nn.Embedding(1, embed_dim)
        self.bg_embed = nn.Embedding(1, embed_dim)
        self.pos_proj = nn.Linear(2, embed_dim)

        # Box embedding: two corner points (top-left, bottom-right)
        self.box_embed = nn.Embedding(2, embed_dim)

        # Dense mask prompt: compress mask to prompt space via conv + gate
        self.mask_proj = nn.Sequential(
            nn.Conv2d(1, embed_dim // 4, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=1),
        )
        self.mask_gate = nn.Parameter(torch.zeros(1))

    def _encode_coords(self, coords: Tensor) -> Tensor:
        """Normalise coordinates from pixel space to [-1, 1] and project."""
        # coords: (B, N_pts, 2) in pixel space
        norm = coords / (self.img_size / 2.0) - 1.0    # → [-1, 1]
        return self.pos_proj(norm)                       # (B, N_pts, embed_dim)

    def forward(
        self,
        points: Optional[Tuple[Tensor, Tensor]] = None,
        boxes:  Optional[Tensor] = None,
        masks:  Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Parameters
        ----------
        points : tuple of (coords, labels) where
                 coords  (B, N, 2) pixel coordinates,
                 labels  (B, N) int: 1 for foreground, 0 for background
        boxes  : (B, 4) bounding boxes [x1, y1, x2, y2] in pixel space
        masks  : (B, 1, H, W) dense mask prompts

        Returns
        -------
        sparse_embeddings : (B, N_sparse, embed_dim) — point/box tokens
        dense_embeddings  : (B, embed_dim, H', W')   — mask feature map
        """
        sparse_parts = []

        if points is not None:
            coords, labels = points                      # (B, N, 2), (B, N)
            pos_embs = self._encode_coords(coords)       # (B, N, D)
            fg_mask  = (labels == 1).unsqueeze(-1).float()
            type_embs = (fg_mask * self.fg_embed.weight +
                         (1 - fg_mask) * self.bg_embed.weight)
            sparse_parts.append(pos_embs + type_embs)

        if boxes is not None:
            B = boxes.shape[0]
            corners = boxes.reshape(B, 2, 2)             # (B, 2, 2) — TL, BR
            corner_pos = self._encode_coords(corners)    # (B, 2, D)
            corner_type = self.box_embed.weight.unsqueeze(0).expand(B, -1, -1)
            sparse_parts.append(corner_pos + corner_type)

        if sparse_parts:
            sparse_embeddings = torch.cat(sparse_parts, dim=1)
        else:
            # No prompts: return a single learned background token
            B = (masks.shape[0] if masks is not None else 1)
            sparse_embeddings = self.bg_embed.weight.unsqueeze(0).expand(B, 1, -1)

        # Dense mask embedding
        if masks is not None:
            dense_embeddings = self.mask_proj(masks)
            dense_embeddings = dense_embeddings * torch.sigmoid(self.mask_gate)
        else:
            dense_embeddings = None

        return sparse_embeddings, dense_embeddings
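

# Illustrative sketch (hypothetical helper, not part of the SAMM code): the
# coordinate normalisation performed by PromptEncoder._encode_coords, written
# out for a single point. Pixel coordinates in [0, img_size] map linearly to
# [-1, 1], so a click at the same relative position gives the same encoding
# whether the micrograph was resized to 512 or 1024 pixels. Like the module
# above, it uses one img_size for both axes, i.e. it assumes square inputs.
def _normalise_point(x, y, img_size):
    half = img_size / 2.0
    return (x / half - 1.0, y / half - 1.0)

# Usage: _normalise_point(128, 128, 512) and _normalise_point(256, 256, 1024)
# both give (-0.5, -0.5).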


# ─── SECTION 6: Mask Decoder ──────────────────────────────────────────────────

class TwoWayAttention(nn.Module):
    """
    Two-way cross-attention block: token-to-image and image-to-token attention
    as used in the mask decoder (Section 4.2.1(d)).
    """

    def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.0):
        super().__init__()
        self.tok_to_img = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.img_to_tok = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.mlp   = nn.Sequential(
            nn.Linear(dim, dim * 2), nn.GELU(), nn.Linear(dim * 2, dim)
        )

    def forward(self, tokens: Tensor, image_feats: Tensor) -> Tuple[Tensor, Tensor]:
        """tokens: (B, N_tok, D), image_feats: (B, N_pix, D)"""
        # Token attends to image
        q = self.norm1(tokens)
        kv = image_feats
        tok_out, _ = self.tok_to_img(q, kv, kv)
        tokens = tokens + tok_out

        # Image attends to tokens
        q = self.norm2(image_feats)
        img_out, _ = self.img_to_tok(q, tokens, tokens)
        image_feats = image_feats + img_out

        # Token MLP
        tokens = tokens + self.mlp(self.norm3(tokens))
        return tokens, image_feats


class SAMMMaskDecoder(nn.Module):
    """
    Mask decoder for SAMM (Section 4.2.1(d)).

    Key innovation over standard SAM2 decoder: incorporates a cross-scale
    feature fusion module that combines high-level semantic features from the
    encoder bottleneck with low-level fine-grained features from earlier stages.
    This directly improves segmentation of blurred or overlapping phase boundaries
    that are characteristic of material micrographs.

    Outputs:
      - Binary segmentation masks (num_mask_tokens candidates)
      - IoU confidence scores per mask (used in IoU-aware loss, Eq. 2)

    Parameters
    ----------
    embed_dim       : image embedding channel dimension
    prompt_dim      : prompt embedding dimension
    num_mask_tokens : number of mask output candidates (SAM2 default: 4)
    encoder_s1_dim  : channel dim of encoder stage-1 features (for fusion)
    encoder_s3_dim  : channel dim of encoder stage-3 features (for fusion)
    """

    def __init__(
        self,
        embed_dim: int = 384,
        prompt_dim: int = 256,
        num_mask_tokens: int = 4,
        encoder_s1_dim: int = 768,
        encoder_s3_dim: int = 768,
    ):
        super().__init__()
        self.num_mask_tokens = num_mask_tokens

        # Learnable mask and IoU tokens
        self.mask_tokens = nn.Embedding(num_mask_tokens, prompt_dim)
        self.iou_token   = nn.Embedding(1, prompt_dim)

        # Project image embedding to prompt dimension for attention
        self.img_proj = nn.Linear(embed_dim, prompt_dim)

        # Two-way attention layers
        self.two_way_layers = nn.ModuleList([
            TwoWayAttention(prompt_dim, num_heads=8)
            for _ in range(2)
        ])

        # Cross-scale fusion: combines s3 (deep) with s1 (fine-grained)
        self.cross_scale_fusion = CrossScaleFusion(
            high_dim=encoder_s3_dim,
            low_dim=encoder_s1_dim,
            out_dim=prompt_dim,
        )
        self.fused_proj = nn.Conv2d(prompt_dim, embed_dim, kernel_size=1)

        # Upsampling: 2× per stage to recover full resolution
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim // 2, kernel_size=2, stride=2),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim // 2, embed_dim // 4, kernel_size=2, stride=2),
            nn.GELU(),
        )

        # Per-mask MLP heads: token → mask prediction
        self.mask_mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(prompt_dim, prompt_dim),
                nn.GELU(),
                nn.Linear(prompt_dim, embed_dim // 4),
            )
            for _ in range(num_mask_tokens)
        ])

        # IoU confidence MLP head
        self.iou_head = nn.Sequential(
            nn.Linear(prompt_dim, 256),
            nn.GELU(),
            nn.Linear(256, num_mask_tokens),
        )

    def forward(
        self,
        image_embedding: Tensor,
        prompt_sparse: Tensor,
        encoder_feats: Dict[str, Tensor],
        prompt_dense: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Parameters
        ----------
        image_embedding : (B, embed_dim, H', W') — from memory encoder
        prompt_sparse   : (B, N_tok, prompt_dim) — from prompt encoder
        encoder_feats   : dict with 's1', 's3' from image encoder
        prompt_dense    : (B, prompt_dim, H', W') optional dense mask embedding

        Returns
        -------
        masks    : (B, num_mask_tokens, H_orig, W_orig) — predicted binary masks
        iou_pred : (B, num_mask_tokens) — predicted IoU confidence scores
        """
        B, D, H, W = image_embedding.shape

        # ── Cross-scale feature fusion ───────────────────────────────────────
        fused = self.cross_scale_fusion(
            high_feat=encoder_feats['s3'],
            low_feat=encoder_feats['s1'],
        )
        # Add fused features to image embedding (residual refinement)
        fused_proj = self.fused_proj(fused)  # → (B, embed_dim, H', W')
        if fused_proj.shape != image_embedding.shape:
            fused_proj = F.interpolate(fused_proj, size=(H, W), mode='bilinear', align_corners=False)
        image_embedding = image_embedding + fused_proj

        # ── Add dense prompt if provided ─────────────────────────────────────
        if prompt_dense is not None:
            pd = F.interpolate(prompt_dense, size=(H, W), mode='bilinear', align_corners=False)
            # Project dense prompt to embed_dim and add
            dense_emb = pd.flatten(class="dc">2).transpose(class="dc">1, class="dc">2)  # (B, HW, prompt_dim)
            # Resize if needed
            img_flat = image_embedding.flatten(class="dc">2).transpose(class="dc">1, class="dc">2)  # (B, HW, D)
            if dense_emb.shape[-class="dc">1] != img_flat.shape[-class="dc">1]:
                dense_emb = F.pad(dense_emb, (class="dc">0, img_flat.shape[-class="dc">1] - dense_emb.shape[-class="dc">1]))
            image_embedding = (img_flat + dense_emb).transpose(class="dc">1, class="dc">2).reshape(B, D, H, W)

        # ── Two-way attention between tokens and image ────────────────────────
        # Concatenate mask tokens + IoU token + prompt sparse tokens
        mask_tok = self.mask_tokens.weight.unsqueeze(0).expand(B, -1, -1)  # (B, num_masks, D')
        iou_tok  = self.iou_token.weight.unsqueeze(0).expand(B, -1, -1)    # (B, 1, D')
        if prompt_sparse.shape[0] != B:
            # The prompt encoder returns a batch-1 placeholder token when no
            # prompts are given; broadcast it across the image batch
            prompt_sparse = prompt_sparse.expand(B, -1, -1)
        tokens   = torch.cat([mask_tok, iou_tok, prompt_sparse], dim=1)    # (B, K, D')

        # Project image features to prompt_dim for cross-attention
        img_flat = image_embedding.flatten(2).transpose(1, 2)  # (B, HW, D)
        img_proj = self.img_proj(img_flat)                       # (B, HW, D')

        for layer in self.two_way_layers:
            tokens, img_proj = layer(tokens, img_proj)

        # ── Extract per-mask token outputs ────────────────────────────────────
        mask_tokens_out = tokens[:, :self.num_mask_tokens, :]   # (B, num_masks, D')
        iou_token_out   = tokens[:, self.num_mask_tokens, :]     # (B, D')

        # ── Generate masks via dot product with upsampled image features ──────
        # The attended image tokens (img_proj) have prompt_dim channels, so the
        # upsampling path starts from the fused image embedding (embed_dim channels)
        upsampled = self.upsample(image_embedding)       # (B, embed_dim//4, H*4, W*4)

        masks_list = []
        for i, mlp in enumerate(self.mask_mlps):
            tok_proj = mlp(mask_tokens_out[:, i, :])    # (B, embed_dim//4)
            # dot product with upsampled features: (B, H*4, W*4)
            mask = torch.einsum('bd,bdhw->bhw', tok_proj, upsampled).unsqueeze(1)
            masks_list.append(mask)
        masks = torch.cat(masks_list, dim=1)             # (B, num_mask_tokens, H*4, W*4)

        # ── IoU confidence scores ─────────────────────────────────────────────
        iou_pred = self.iou_head(iou_token_out)          # (B, num_mask_tokens)

        return masks, iou_pred
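

# Illustrative sketch (hypothetical helper, not part of the SAMM code): the
# per-token mask head above computes torch.einsum('bd,bdhw->bhw', tok_proj,
# upsampled). For one batch item, each pixel's mask logit is the dot product
# between the projected mask token (length C) and that pixel's C-dimensional
# feature vector. Pure-Python version on nested lists indexed feats[c][h][w].
def _token_mask_logits(token, feats):
    channels, height, width = len(feats), len(feats[0]), len(feats[0][0])
    assert len(token) == channels, "token length must match feature channels"
    return [
        [sum(token[c] * feats[c][h][w] for c in range(channels)) for w in range(width)]
        for h in range(height)
    ]

# Usage: with token [2.0, 1.0] and a 2-channel 2x2 feature map, pixel (0, 0)
# with features (1.0, 0.5) gets logit 2.0 * 1.0 + 1.0 * 0.5 = 2.5.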


# ─── SECTION 7: Full SAMM Model ───────────────────────────────────────────────

class SAMM(nn.Module):
    """
    SAMM: Segment Anything for Material Micrographs.

    An end-to-end fine-tuning framework built on the SAM2 architecture,
    specifically adapted for universal material microstructure segmentation
    (Advanced Powder Materials 5 (2026) 100404).

    Key differences from SAM2:
      1. Full-parameter fine-tuning — all encoder layers are unfrozen (Strategy 2
         in ablation: +11.17% mIoU over frozen SAM2).
      2. Cross-scale feature fusion in decoder — explicitly aligns high-level
         semantic features with fine-grained boundary information.
      3. Hybrid BCE + IoU-aware loss — jointly optimises pixel accuracy and
         geometric consistency (Strategy 3: +3.22% mIoU).
      4. Normalised coordinate prompts — eliminates input-size bias in the
         prompt encoder (part of Strategy 4).

    Architecture flow:
      Image → ImageEncoder → MemoryEncoder → (+ PromptEncoder) → MaskDecoder → Masks + IoU scores

    Parameters
    ----------
    config : SAMMConfig instance
    """

    def __init__(self, config: Optional[SAMMConfig] = None):
        super().__init__()
        cfg = config or SAMMConfig()
        self.cfg = cfg

        D = cfg.embed_dim

        # Component modules
        self.image_encoder = SAMMImageEncoder(
            img_size=cfg.img_size,
            patch_size=cfg.patch_size,
            in_channels=cfg.in_channels,
            embed_dim=D,
            depth=cfg.encoder_depth,
            num_heads=cfg.encoder_heads,
            mlp_ratio=cfg.mlp_ratio,
        )
        self.memory_encoder = MemoryEncoder(
            embed_dim=D // 2,
            memory_dim=cfg.decoder_dim,
            depth=cfg.memory_depth,
            num_heads=8,
        )
        self.prompt_encoder = PromptEncoder(
            embed_dim=cfg.prompt_embed_dim,
            img_size=cfg.img_size,
        )
        self.mask_decoder = SAMMMaskDecoder(
            embed_dim=D // 2,
            prompt_dim=cfg.prompt_embed_dim,
            num_mask_tokens=cfg.num_mask_tokens,
            encoder_s1_dim=D,
            encoder_s3_dim=D,
        )

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
                if m.weight is not None: nn.init.ones_(m.weight)
                if m.bias is not None:   nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        images: Tensor,
        points: Optional[Tuple[Tensor, Tensor]] = None,
        boxes:  Optional[Tensor]  = None,
        masks:  Optional[Tensor]  = None,
        memory: Optional[Tensor]  = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Forward pass through the full SAMM pipeline.

        Parameters
        ----------
        images  : (B, C, H, W) — material micrograph (SEM/TEM/OM/XCT)
        points  : optional tuple (coords, labels) for point prompts
        boxes   : optional (B, 4) bounding box prompts
        masks   : optional (B, 1, H, W) dense mask prompts
        memory  : optional (B, M, cfg.decoder_dim) memory bank tokens

        Returns
        -------
        masks_pred : (B, num_mask_tokens, H, W) — segmentation logits
        iou_pred   : (B, num_mask_tokens) — IoU confidence scores
        """
        # Stage 1: Extract multi-scale features (all params unfrozen)
        enc_feats = self.image_encoder(images)

        # Stage 2: Refine with memory context
        img_emb = self.memory_encoder(enc_feats['image_embedding'], memory)

        # Stage 3: Encode prompts
        sparse_emb, dense_emb = self.prompt_encoder(points=points, boxes=boxes, masks=masks)

        # Stage 4: Decode masks
        masks_pred, iou_pred = self.mask_decoder(
            image_embedding=img_emb,
            prompt_sparse=sparse_emb,
            encoder_feats=enc_feats,
            prompt_dense=dense_emb,
        )

        # Upsample masks to original image resolution
        masks_pred = F.interpolate(
            masks_pred, size=images.shape[-2:], mode='bilinear', align_corners=False
        )
        return masks_pred, iou_pred

    @torch.no_grad()
    def predict_best_mask(
        self,
        images: Tensor,
        points: Optional[Tuple[Tensor, Tensor]] = None,
        boxes: Optional[Tensor] = None,
        threshold: float = 0.0,
    ) -> Tensor:
        """
        Inference convenience method: returns the single best mask (highest IoU score).

        Parameters
        ----------
        images    : (B, C, H, W)
        points    : optional point prompts
        boxes     : optional box prompts
        threshold : logit threshold for binarising output mask

        Returns
        -------
        best_mask : (B, H, W) — binary segmentation mask
        """
        masks_pred, iou_pred = self.forward(images, points=points, boxes=boxes)
        best_idx  = iou_pred.argmax(dim=1)              # (B,)
        best_mask = masks_pred[
            torch.arange(masks_pred.shape[0], device=masks_pred.device), best_idx
        ]                                                # (B, H, W)
        return (best_mask > threshold).float()
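

# Illustrative sketch (hypothetical helper, not part of the SAMM code): the
# selection rule of predict_best_mask on plain lists. The candidate with the
# highest predicted IoU is kept and its logits are thresholded; a logit
# threshold of 0.0 is equivalent to sigmoid(logit) > 0.5. Ties resolve to the
# lowest index here.
def _select_best_mask(candidate_logits, iou_scores, threshold=0.0):
    best = max(range(len(iou_scores)), key=lambda k: iou_scores[k])
    binary = [1.0 if v > threshold else 0.0 for v in candidate_logits[best]]
    return binary, best

# Usage: with IoU scores [0.4, 0.9], candidate 1 is chosen and its logits
# [-0.5, 0.7, 0.0] binarise to [0.0, 1.0, 0.0] (a logit of exactly 0 is off).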


# ─── SECTION 8: Hybrid Loss Function ─────────────────────────────────────────

class SAMMSegLoss(nn.Module):
    """
    Binary Cross-Entropy segmentation loss (Eq. 1 of the paper).

    L_seg = -(1/N) Σ_i [m_i · log(σ(M_prd^(i) + ε)) + (1-m_i) · log(1-σ(M_prd^(i) + ε))]

    The ε inside σ follows Eq. 1 as written; the ε also added inside each log
    below is what prevents log(0) when the sigmoid saturates, which matters for
    low-contrast phase boundaries in SEM images.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits: Tensor, targets: Tensor) -> Tensor:
        """
        Parameters
        ----------
        logits  : (B, H, W) — raw mask logits (before sigmoid)
        targets : (B, H, W) — binary ground-truth masks

        Returns
        -------
        loss : scalar
        """
        prob = torch.sigmoid(logits + self.eps)   # ε inside σ, as in Eq. 1
        bce  = -(
            targets * torch.log(prob + self.eps)  # ε inside log guards against log(0)
            + (1 - targets) * torch.log(1 - prob + self.eps)
        )
        return bce.mean()
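

# Illustrative sketch (a swapped-in formulation, not Eq. 1 as written): the
# numerically stable per-pixel BCE on a logit x and label y,
#     l(x, y) = max(x, 0) - x * y + log(1 + exp(-|x|)),
# which equals -[y log σ(x) + (1 - y) log(1 - σ(x))] without ever evaluating
# log(0). PyTorch documents F.binary_cross_entropy_with_logits as using the
# log-sum-exp trick for the same reason; SAMMSegLoss above relies on ε instead.
# The helper names are hypothetical.
import math


def _stable_bce(x, y):
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def _naive_bce(x, y):
    p = 1.0 / (1.0 + math.exp(-x))
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))

# Usage: the two agree for moderate logits, but only _stable_bce stays finite
# for an extreme logit such as -800.0 (the naive form overflows in math.exp).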


class SAMMIoULoss(nn.Module):
    """
    IoU-aware auxiliary loss (Eq. 2 of the paper).

    L_iou = (1/N) Σ_i |S_prd^(i) - IoU(𝕀_{σ>0.5}(M_prd^(i)), M_gt^(i))|

    Regresses the predicted IoU confidence toward the actual mask-GT IoU,
    suppressing fragmented predictions and mitigating class imbalance —
    both common in sparse-phase microstructure datasets.
    """

    def forward(self, iou_pred: Tensor, logits: Tensor, targets: Tensor) -> Tensor:
        """
        Parameters
        ----------
        iou_pred : (B, num_mask_tokens) — predicted IoU scores
        logits   : (B, num_mask_tokens, H, W) — mask logits
        targets  : (B, H, W) — binary GT mask (compared against every candidate)

        Returns
        -------
        loss : scalar
        """
        B, K, H, W = logits.shape
        targets_exp = targets.unsqueeze(1).expand_as(logits)  # (B, K, H, W)

        # Binarise predictions at sigmoid > 0.5
        pred_bin = (torch.sigmoid(logits) > 0.5).float()

        # Compute actual IoU per batch item per mask token
        eps = 1e-5
        inter = (pred_bin * targets_exp).sum(dim=(-2, -1))         # (B, K)
        union = ((pred_bin + targets_exp) > 0).float().sum(dim=(-2, -1))
        actual_iou = (inter + eps) / (union + eps)                 # (B, K)

        return torch.abs(iou_pred - actual_iou.detach()).mean()
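

# Illustrative sketch (hypothetical helper, not part of the SAMM code): the
# IoU-aware term of Eq. 2 for one sample on flat 0/1 lists. Each candidate's
# logits are binarised at sigmoid > 0.5 (equivalently logit > 0), its IoU with
# the ground truth is computed with the same ε smoothing as above, and the
# loss is the mean absolute gap to the predicted IoU scores.
def _iou_aware_term(iou_pred, candidate_logits, target, eps=1e-5):
    gaps = []
    for score, logits in zip(iou_pred, candidate_logits):
        pred = [1 if v > 0 else 0 for v in logits]
        inter = sum(p & t for p, t in zip(pred, target))
        union = sum(p | t for p, t in zip(pred, target))
        gaps.append(abs(score - (inter + eps) / (union + eps)))
    return sum(gaps) / len(gaps)

# Usage: for target [1, 1, 0, 0], a candidate predicting [1, 0, 0, 0] has
# IoU 1/2 and one predicting the target exactly has IoU 1, so scores of
# [0.5, 0.8] give gaps of about 0 and 0.2 and a loss of about 0.1.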


class SAMMHybridLoss(nn.Module):
    """
    Total hybrid loss (Eq. 3 of the paper).

    L_total = L_seg + λ · L_iou,   λ = 0.05 (from grid search)

    Jointly captures fine-grained boundaries (via BCE) and global
    morphological consistency (via IoU-aware regression).
    """

    def __init__(self, lambda_iou: float = 0.05):
        super().__init__()
        self.seg_loss = SAMMSegLoss()
        self.iou_loss = SAMMIoULoss()
        self.lambda_iou = lambda_iou

    def forward(
        self,
        masks_pred: Tensor,
        iou_pred:   Tensor,
        targets:    Tensor,
    ) -> Tuple[Tensor, Dict[str, float]]:
        """
        Parameters
        ----------
        masks_pred : (B, K, H, W) — all mask logits
        iou_pred   : (B, K) — predicted IoU scores
        targets    : (B, H, W) — binary ground-truth masks

        Returns
        -------
        total_loss  : scalar Tensor
        loss_detail : dict with individual loss values for logging
        """
        # Use best-IoU-scoring mask for segmentation loss
        best_idx  = iou_pred.argmax(dim=1)                       # (B,)
        best_mask = masks_pred[torch.arange(masks_pred.shape[0], device=masks_pred.device), best_idx]  # (B, H, W)

        l_seg = self.seg_loss(best_mask, targets.float())
        l_iou = self.iou_loss(iou_pred, masks_pred, targets.float())
        total = l_seg + self.lambda_iou * l_iou

        return total, {'seg': l_seg.item(), 'iou': l_iou.item(), 'total': total.item()}
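

# Worked arithmetic for Eq. 3 (hypothetical loss values chosen only for
# illustration, not results from the paper): with λ = 0.05 the IoU-aware term
# acts as a light regulariser, so an IoU-regression error of 0.20 adds only
# 0.01 to a BCE of 0.30. The helper name is hypothetical.
def _hybrid_total(l_seg, l_iou, lambda_iou=0.05):
    """L_total = L_seg + λ · L_iou for plain floats."""
    return l_seg + lambda_iou * l_iou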


# ─── SECTION 9: Evaluation Metrics ───────────────────────────────────────────

def compute_miou(pred_mask: Tensor, gt_mask: Tensor, eps: float = class="dc">1e-5) -> float:
    """
    Mean Intersection-over-Union for a single binary mask pair.

    mIoU = (IoU_fg + IoU_bg) / 2  (standard binary segmentation mIoU)

    Parameters
    ----------
    pred_mask : (H, W) binary prediction
    gt_mask   : (H, W) binary ground truth

    Returns
    -------
    miou : float in [0, 1]
    """
    pred = pred_mask.bool()
    gt   = gt_mask.bool()

    # Foreground IoU
    inter_fg = (pred & gt).float().sum()
    union_fg = (pred | gt).float().sum()
    iou_fg   = (inter_fg + eps) / (union_fg + eps)

    # Background IoU
    inter_bg = (~pred & ~gt).float().sum()
    union_bg = (~pred | ~gt).float().sum()
    iou_bg   = (inter_bg + eps) / (union_bg + eps)

    return ((iou_fg + iou_bg) / 2).item()
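

# Illustrative sketch (hypothetical helper, not part of the SAMM code):
# compute_miou on flat 0/1 lists, with a worked example. For
# pred = [1, 1, 0, 0] and gt = [1, 0, 0, 0], the foreground IoU is 1/2 and
# the background IoU is 2/3, so mIoU = (1/2 + 2/3) / 2 ≈ 0.583. Because both
# classes are averaged, predicting all background on a sparse phase scores at
# most about 0.5.
def _binary_miou(pred, gt, eps=1e-5):
    inter_fg = sum(p & g for p, g in zip(pred, gt))
    union_fg = sum(p | g for p, g in zip(pred, gt))
    inter_bg = sum((1 - p) & (1 - g) for p, g in zip(pred, gt))
    union_bg = sum((1 - p) | (1 - g) for p, g in zip(pred, gt))
    iou_fg = (inter_fg + eps) / (union_fg + eps)
    iou_bg = (inter_bg + eps) / (union_bg + eps)
    return (iou_fg + iou_bg) / 2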


def compute_boundary_f1(
    pred_mask: Tensor,
    gt_mask: Tensor,
    tolerance: int = 2,
    eps: float = 1e-5,
) -> float:
    """
    Boundary F1 score (Table 5 in the paper, tolerance = 2 pixels).

    Evaluates precision and recall of boundary pixel predictions within
    a `tolerance`-pixel distance. Used alongside mIoU to capture the model's
    ability to reproduce fine phase boundaries and grain edges.

    Parameters
    ----------
    pred_mask  : (H, W) binary mask
    gt_mask    : (H, W) binary mask
    tolerance  : pixel tolerance for boundary matching

    Returns
    -------
    bf1 : float in [0, 1]
    """
    def _extract_boundary(mask: Tensor) -> Tensor:
        """Outer boundary: background pixels adjacent to the mask (3x3 max-pool dilation minus the mask)."""
        m = mask.float().unsqueeze(0).unsqueeze(0)   # (1, 1, H, W)
        dilated  = F.max_pool2d(m, kernel_size=3, stride=1, padding=1)
        boundary = (dilated - m)[0, 0] > 0           # index rather than squeeze(): safe if H or W == 1
        return boundary

    def _dilate(mask: Tensor, radius: int) -> Tensor:
        m = mask.float().unsqueeze(class="dc">0).unsqueeze(class="dc">0)
        k = 2 * radius + 1
        return F.max_pool2d(m, k, stride=1, padding=radius)[0, 0] > 0

    pred_bd = _extract_boundary(pred_mask)
    gt_bd   = _extract_boundary(gt_mask)

    gt_dilated   = _dilate(gt_bd, tolerance)
    pred_dilated = _dilate(pred_bd, tolerance)

    precision = (pred_bd & gt_dilated).float().sum() / (pred_bd.float().sum() + eps)
    recall    = (gt_bd & pred_dilated).float().sum() / (gt_bd.float().sum() + eps)

    bf1 = (2 * precision * recall / (precision + recall + eps)).item()
    return bf1
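

# Illustrative sketch (hypothetical helpers, not part of the SAMM code):
# compute_boundary_f1 on a 2-D 0/1 list, mirroring the max-pooling version
# above. A boundary pixel is a background pixel with a foreground pixel in its
# 3x3 neighbourhood (what dilated - m > 0 selects), and a boundary pixel is
# matched if a boundary pixel of the other mask lies inside a
# (2r+1) x (2r+1) square window, i.e. within Chebyshev distance r = tolerance.
def _outer_boundary(mask):
    h, w = len(mask), len(mask[0])
    return {
        (i, j)
        for i in range(h)
        for j in range(w)
        if mask[i][j] == 0
        and any(
            mask[a][b]
            for a in range(max(0, i - 1), min(h, i + 2))
            for b in range(max(0, j - 1), min(w, j + 2))
        )
    }


def _boundary_f1(pred, gt, tolerance=2, eps=1e-5):
    pred_bd, gt_bd = _outer_boundary(pred), _outer_boundary(gt)

    def _near(p, pts):
        return any(max(abs(p[0] - q[0]), abs(p[1] - q[1])) <= tolerance for q in pts)

    precision = sum(_near(p, gt_bd) for p in pred_bd) / (len(pred_bd) + eps)
    recall = sum(_near(g, pred_bd) for g in gt_bd) / (len(gt_bd) + eps)
    return 2 * precision * recall / (precision + recall + eps)

# Usage: a 2x2 square compared with itself scores ≈ 1.0; an empty prediction
# scores 0.0 because it has no boundary pixels to match.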


class MicroscopyMetrics:
    """Accumulates mIoU and Boundary F1 across a validation epoch."""

    def __init__(self):
        self.miou_sum  = 0.0
        self.bf1_sum   = 0.0
        self.count     = 0

    @torch.no_grad()
    def update(self, pred_logits: Tensor, gt_masks: Tensor, iou_pred: Optional[Tensor] = None):
        """
        pred_logits: (B, K, H, W), gt_masks: (B, H, W), iou_pred: optional (B, K)

        If iou_pred is given, the highest-scoring candidate is evaluated (as in
        predict_best_mask); otherwise the logits are averaged over the K candidates.
        """
        B = pred_logits.shape[0]
        if iou_pred is not None:
            best_idx = iou_pred.argmax(dim=1)
            logits = pred_logits[torch.arange(B, device=pred_logits.device), best_idx]
        else:
            logits = pred_logits.mean(dim=1)
        best_masks = (logits > 0).float()                  # (B, H, W)
        for b in range(B):
            self.miou_sum += compute_miou(best_masks[b], gt_masks[b])
            self.bf1_sum  += compute_boundary_f1(best_masks[b], gt_masks[b])
            self.count += 1

    def result(self) -> Dict[str, float]:
        n = max(1, self.count)
        return {'mIoU': self.miou_sum / n, 'BoundaryF1': self.bf1_sum / n}

    def reset(self):
        self.miou_sum = 0.0
        self.bf1_sum  = 0.0
        self.count    = 0


# ─── SECTION 10: Dataset Helpers ─────────────────────────────────────────────

# Material dataset metadata matching Table 2 of the paper
DATASET_META = {
    # Name: (in_channels, img_size, description)
    'Data1':  (1, 512, 'Superalloy η/σ phases (SEM-BSE, 2501×2501)'),
    'Data2':  (1, 512, 'Superalloy γ′ phase (SEM-SE, 800°C anneal)'),
    'Data3':  (1, 512, 'Superalloy γ′ phase (SEM-SE, 900°C anneal)'),
    'Data4':  (1, 512, 'Superalloy γ′ phase (SEM-SE, 1000°C anneal)'),
    'Data5':  (1, 512, 'Ni-Co superalloy γ′ with Nb/Ta additions'),
    'Data6':  (1, 512, 'Ni wrought superalloy tri-modal γ′ distribution'),
    'Data7':  (1, 512, 'IN718 AM powder SEM (PREP + gas atomisation)'),
    'Data8':  (1, 512, 'Rare metal powders: Pt, PtRh30 (SEM)'),
    'Data9':  (1, 500, 'V2O5 nanowires (SEM + X-ray)'),
    'Data10': (1, 512, 'Ti-6Al-4V α phase (SEM)'),
    'Data11': (1, 512, 'Multi-alloy γ′ benchmark (Stuckner et al.)'),
    'Data12': (1, 512, 'δ/o phases nanocrystalline (Yildirim et al.)'),
    'Data13': (1, 512, 'ε phase carbon steel (Bayesian SegBPIS)'),
}


class MaterialMicrographDataset(Dataset):
    """
    Minimal synthetic stand-in for the 13-subset microscopy benchmark described
    in Table 2 of the SAMM paper (random images and masks with the same channel
    counts and image sizes).

    Replace with real data from:
      - Self-collected SEM datasets from Central South University (Data 1–8)
      - Public datasets: Lin et al. 2022, Fotos et al. 2023, Stuckner et al. 2022,
        Yildirim & Cole 2021

    Parameters
    ----------
    dataset_name  : one of Data1–Data13 (controls image size / channels)
    num_samples   : number of synthetic samples
    split         : 'train', 'val', or 'test'
    """

    SPLIT_RATIOS = {'train': 0.7, 'val': 0.15, 'test': 0.15}

    def __init__(
        self,
        dataset_name: str = 'Data1',
        num_samples: int = 100,
        split: str = 'train',
    ):
        self.dataset_name = dataset_name
        meta = DATASET_META.get(dataset_name, (1, 512, 'unknown'))
        self.in_channels = meta[0]
        self.img_size    = meta[1]
        self.desc        = meta[2]
        self.split       = split

        # Determine split range
        n_train = int(num_samples * self.SPLIT_RATIOS['train'])
        n_val   = int(num_samples * self.SPLIT_RATIOS['val'])
        if split == 'train':
            self.indices = list(range(n_train))
        elif split == 'val':
            self.indices = list(range(n_train, n_train + n_val))
        else:
            self.indices = list(range(n_train + n_val, num_samples))

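        # Deterministic per-dataset seed on a local generator: Python's built-in
        # hash() of a str is salted per process (see PYTHONHASHSEED), and
        # torch.manual_seed would also reset the global RNG as a side effect
        seed = int.from_bytes(dataset_name.encode(), 'little') % 10000
        g = torch.Generator().manual_seed(seed)
        total = num_samples
        self._images = torch.randn(total, self.in_channels, self.img_size, self.img_size, generator=g)
        # Binary masks with realistic sparsity (20–60% positive pixels)
        fill = 0.2 + torch.rand(1, generator=g).item() * 0.4
        self._masks  = (torch.rand(total, self.img_size, self.img_size, generator=g) < fill).float()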

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        real_idx = self.indices[idx]
        return self._images[real_idx], self._masks[real_idx]


class CombinedDataset(Dataset):
    """
    Unified dataset that concatenates several MaterialMicrographDataset subsets
    for SAMM training (Data 1–7; the full 13-subset dataset described in
    Section 4.1 contains 3,490 images).
    """

    def __init__(self, dataset_names: List[str], num_samples_per: int = 50, split: str = 'train'):
        self.datasets = [
            MaterialMicrographDataset(name, num_samples_per, split)
            for name in dataset_names
        ]
        self.cumulative = []
        running = 0
        for ds in self.datasets:
            running += len(ds)
            self.cumulative.append(running)

    def __len__(self):
        return self.cumulative[-1]

    def __getitem__(self, idx):
        for i, end in enumerate(self.cumulative):
            if idx < end:
                start = self.cumulative[i - 1] if i > 0 else 0
                return self.datasets[i][idx - start]
        raise IndexError(f"Index {idx} out of range")


# ─── SECTION 11: Training Loop ────────────────────────────────────────────────

def build_samm_optimizer(model: nn.Module, cfg: SAMMConfig) -> torch.optim.Optimizer:
    """
    AdamW over all parameters as a single group. With full-parameter
    fine-tuning the encoder is unfrozen and shares the learning rate and
    weight decay of the rest of the model.

    Paper: lr=1e-5, weight_decay=4e-5 (Section 4.2.3)
    """
    return torch.optim.AdamW(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        betas=(0.9, 0.999),
    )


def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: SAMMHybridLoss,
    device: torch.device,
    epoch: int,
    grad_clip: float = 1.0,
) -> float:
    """
    Train SAMM for one epoch.

    Implements the training strategy from Section 4.2.3:
    - Mixed precision (FP16/FP32 autocast) when running on CUDA
    - Gradient clipping for numerical stability (applied after unscaling under AMP)
    - Dynamic batching is approximated by the DataLoader's fixed batch size

    Returns
    -------
    avg_loss : mean total loss over the epoch
    """
    model.train()
    total_loss = 0.0
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

    for step, (images, gt_masks) in enumerate(loader):
        images   = images.to(device)
        gt_masks = gt_masks.to(device)

        optimizer.zero_grad()

        if scaler:
            with torch.cuda.amp.autocast():
                masks_pred, iou_pred = model(images)
                loss, detail = criterion(masks_pred, iou_pred, gt_masks)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            masks_pred, iou_pred = model(images)
            loss, detail = criterion(masks_pred, iou_pred, gt_masks)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        total_loss += detail['total']
        if step % 5 == 0:
            print(
                f"  Epoch {epoch} | Step {step:3d}/{len(loader)} | "
                f"Loss {detail['total']:.4f} "
                f"(seg={detail['seg']:.4f}, iou={detail['iou']:.4f})"
            )

    return total_loss / len(loader)


@torch.no_grad()
def validate(
    model: nn.Module,
    loader: DataLoader,
    criterion: SAMMHybridLoss,
    metrics: MicroscopyMetrics,
    device: torch.device,
) -> Tuple[float, Dict[str, float]]:
    """Evaluate SAMM on a validation split. Returns (avg_loss, metrics_dict)."""
    model.eval()
    metrics.reset()
    total_loss = 0.0

    for images, gt_masks in loader:
        images   = images.to(device)
        gt_masks = gt_masks.to(device)
        masks_pred, iou_pred = model(images)
        loss, detail = criterion(masks_pred, iou_pred, gt_masks)
        total_loss += detail['total']
        metrics.update(masks_pred, gt_masks)

    return total_loss / len(loader), metrics.result()


def run_training(
    train_datasets: List[str] = None,
    epochs: int = 3,
    batch_size: int = 2,
    device_str: str = 'cpu',
    num_samples_per: int = 20,
):
    """
    Full SAMM training pipeline on the combined Data 1–7 microscopy dataset.

    Paper setup: training to convergence with dynamic batching, AdamW
    (lr=1e-5, weight_decay=4e-5), and mixed precision on NVIDIA GPUs.
    """
    if train_datasets is None:
        train_datasets = [f'Data{i}' for i in range(1, 8)]   # Data 1–7

    device = torch.device(device_str)
    print(f"\n{'='*60}")
    print(f"  SAMM Training — {len(train_datasets)} material datasets")
    print(f"  Device: {device} | Epochs: {epochs} | Batch: {batch_size}")
    print(f"{'='*60}\n")

    # Datasets
    train_ds  = CombinedDataset(train_datasets, num_samples_per, split='train')
    val_ds    = CombinedDataset(train_datasets, num_samples_per, split='val')
    train_ldr = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=0)
    val_ldr   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=0)

    # Model: a smaller config for the smoke test
    cfg = SAMMConfig(img_size=64, embed_dim=96, encoder_depth=3, encoder_heads=3,
                     memory_depth=2, decoder_dim=64, prompt_embed_dim=64, num_mask_tokens=2)
    model = SAMM(cfg).to(device)
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {params / 1e6:.2f} M")

    optimizer = build_samm_optimizer(model, cfg)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = SAMMHybridLoss(lambda_iou=cfg.lambda_iou)
    metrics   = MicroscopyMetrics()

    best_miou = 0.0
    for epoch in range(1, epochs + 1):
        train_loss = train_one_epoch(model, train_ldr, optimizer, criterion, device, epoch)
        val_loss, val_m = validate(model, val_ldr, criterion, metrics, device)
        scheduler.step()
        print(
            f"Epoch {epoch:2d}/{epochs} | "
            f"Train: {train_loss:.4f} | Val: {val_loss:.4f} | "
            f"mIoU: {val_m['mIoU']:.4f} | BF1: {val_m['BoundaryF1']:.4f}"
        )
        if val_m['mIoU'] > best_miou:
            best_miou = val_m['mIoU']
            print(f"  ✓ New best mIoU: {best_miou:.4f}")

    print(f"\nTraining complete. Best mIoU: {best_miou:.4f}")
    return model


# ─── SECTION 12: Smoke Test ───────────────────────────────────────────────────

if __name__ == '__main__':
    print('=' * 60)
    print('SAMM — Full Architecture Smoke Test')
    print('=' * 60)
    torch.manual_seed(42)
    device = torch.device('cpu')

    # ── 1. Instantiate with small config for fast test ────────────────────────
    print('\n[1/6] Building SAMM (small config for smoke test)...')
    cfg = SAMMConfig(
        img_size=64, in_channels=1,
        embed_dim=96, encoder_depth=3, encoder_heads=3,
        memory_depth=2, decoder_dim=64,
        prompt_embed_dim=64, num_mask_tokens=2,
    )
    model = SAMM(cfg).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'  Trainable params: {n_params / 1e6:.2f} M')
    print('  (Full-size parameter count depends on the chosen backbone scale)')

    # ── 2. Forward pass (no prompts — automatic segmentation mode) ────────────
    print('\n[2/6] Forward pass — no-prompt mode (SEM grayscale 64×64)...')
    images = torch.randn(2, 1, 64, 64)
    with torch.no_grad():
        masks_pred, iou_pred = model(images)
    print(f'  Input:    {tuple(images.shape)}')
    print(f'  Masks:    {tuple(masks_pred.shape)}  (B, K, H, W)')
    print(f'  IoU pred: {tuple(iou_pred.shape)}   (B, K)')
    assert masks_pred.shape == (2, cfg.num_mask_tokens, 64, 64)

    # ── 3. Forward pass with point prompts ───────────────────────────────────
    print('\n[3/6] Forward pass — point-prompt mode...')
    coords  = torch.randint(0, 64, (2, 3, 2)).float()
    labels  = torch.ones(2, 3).long()
    with torch.no_grad():
        masks_p, iou_p = model(images, points=(coords, labels))
    print(f'  Point-prompted masks: {tuple(masks_p.shape)}')

    # ── 4. Loss function verification ─────────────────────────────────────────
    print('\n[4/6] Loss function check...')
    criterion = SAMMHybridLoss(lambda_iou=0.05)
    gt = (torch.rand(2, 64, 64) > 0.5).long()
    loss_val, detail = criterion(masks_pred, iou_pred, gt)
    print(f'  Total loss : {loss_val.item():.4f}')
    print(f'  Seg loss   : {detail["seg"]:.4f}')
    print(f'  IoU loss   : {detail["iou"]:.4f}')

    # ── 5. Metrics ────────────────────────────────────────────────────────────
    print('\n[5/6] Boundary F1 and mIoU metric check...')
    pred_m = torch.rand(64, 64) > 0.5
    gt_m   = torch.rand(64, 64) > 0.5
    miou   = compute_miou(pred_m, gt_m)
    bf1    = compute_boundary_f1(pred_m, gt_m, tolerance=2)
    print(f'  mIoU = {miou:.4f}  |  Boundary F1 = {bf1:.4f}')

    # ── 6. Short training loop ─────────────────────────────────────────────────
    print('\n[6/6] Short training run (2 epochs, 3 datasets, synthetic data)...')
    run_training(
        train_datasets=['Data1', 'Data2', 'Data7'],
        epochs=2, batch_size=2, device_str='cpu', num_samples_per=10
    )

    print('\n' + '=' * 60)
    print('✓  All checks passed. SAMM is ready for training.')
    print('=' * 60)
    print("""
Next steps for real training:
  1. Replace MaterialMicrographDataset with real SEM/TEM image loaders.
  2. Initialise image_encoder with SAM2 pretrained weights:
       model.image_encoder.load_state_dict(sam2_weights, strict=False)
     SAM2 checkpoints use Hiera backbones, so keys must be mapped onto this
     encoder; check the missing/unexpected keys that load_state_dict returns.
     Available from: https://github.com/facebookresearch/sam2
  3. Set img_size=512, embed_dim=768 (ViT-B) or 1024 (ViT-L) for full scale.
  4. Train on Data 1–7 combined (the full 13-subset dataset holds 3,490 images
     and 381,962 masks).
  5. Zero-shot evaluation on Data 8–13 without any fine-tuning.
  6. Enable mixed precision: device_str='cuda' auto-activates the FP16 scaler.
  7. Paper target: 89.68% mIoU (avg over Data 1–7 with all 4 strategies).
""")

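The two numerical safeguards in the training code, gradient clipping in train_one_epoch and the cosine schedule in run_training, can be illustrated without torch. The sketch below is a minimal pure-Python re-derivation, assuming gradients are plain lists of floats; the helper names clip_by_global_norm and cosine_annealing_lr are hypothetical. The first follows the global L2-norm rule documented for torch.nn.utils.clip_grad_norm_, the second the closed form documented for CosineAnnealingLR. Under mixed precision, train_one_epoch calls scaler.unscale_ before clipping so that grad_clip is compared against the true gradient norm rather than the loss-scaled one.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float,
                        eps: float = 1e-6) -> List[List[float]]:
    """Scale every gradient by one shared factor so that their joint L2 norm
    does not exceed max_norm (hypothetical helper, norm_type=2 rule)."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    coef = max_norm / (total_norm + eps)
    if coef >= 1.0:
        # Already within budget: return an unchanged copy.
        return [list(tensor) for tensor in grads]
    # Shrink all gradients by the same factor, preserving their direction.
    return [[g * coef for g in tensor] for tensor in grads]


def cosine_annealing_lr(base_lr: float, t: int, t_max: int,
                        eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: learning rate after t of t_max scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


if __name__ == "__main__":
    # Per-epoch learning rates for a 3-epoch run (T_max=3) at the paper's lr=1e-5.
    for t in range(3):
        print(f"epoch {t + 1}: lr = {cosine_annealing_lr(1e-5, t, 3):.2e}")
```

Because scheduler.step() runs once at the end of each epoch, a run with epochs=3 and the paper's lr=1e-5 trains its three epochs at roughly 1e-5, 7.5e-6 and 2.5e-6.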
Read the Full Paper & Access the Dataset

The complete study — including the 13-subset annotated microscopy dataset (3,490 images, 381,962 masks) — is published open-access in Advanced Powder Materials under CC BY-NC-ND 4.0.
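When the synthetic tensors are swapped for this dataset, the split and indexing logic of MaterialMicrographDataset and CombinedDataset carries over unchanged and can be sanity-checked without torch. The sketch below is a minimal pure-Python version under that assumption; split_indices and locate are hypothetical helper names. It reproduces the 70/15/15 split with int truncation, so the test split absorbs any remainder, and replaces CombinedDataset's linear scan with a bisect_right binary search over the cumulative lengths.

```python
from bisect import bisect_right
from typing import Dict, List, Tuple

SPLIT_RATIOS = {'train': 0.7, 'val': 0.15, 'test': 0.15}


def split_indices(num_samples: int) -> Dict[str, range]:
    """Same int-truncation rule as MaterialMicrographDataset: the test split
    absorbs whatever the truncated train/val counts leave over."""
    n_train = int(num_samples * SPLIT_RATIOS['train'])
    n_val = int(num_samples * SPLIT_RATIOS['val'])
    return {
        'train': range(0, n_train),
        'val': range(n_train, n_train + n_val),
        'test': range(n_train + n_val, num_samples),
    }


def locate(cumulative: List[int], idx: int) -> Tuple[int, int]:
    """Map a global index to (subset_id, local_index) by binary search over the
    running totals, instead of CombinedDataset's linear scan."""
    if not 0 <= idx < cumulative[-1]:
        raise IndexError(f"Index {idx} out of range")
    i = bisect_right(cumulative, idx)  # first subset whose end exceeds idx
    start = cumulative[i - 1] if i > 0 else 0
    return i, idx - start
```

With num_samples_per=10, as in the smoke test, each subset yields 7 training, 1 validation and 2 test samples, so three training subsets give cumulative totals [7, 14, 21] and locate([7, 14, 21], 7) returns (1, 0), the first sample of the second subset.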

Academic Citation:
Tu, J., Wang, Z., Li, W., Tan, L., Huang, L., & Liu, F. (2026). SAMM: A general-purpose segmentation model for material micrographs based on the segment anything model 2. Advanced Powder Materials, 5, 100404. https://doi.org/10.1016/j.apmate.2026.100404

This article is an independent editorial analysis of peer-reviewed research. The Python implementation is an educational adaptation of the paper’s methodology. The original authors provide pretrained weights and datasets via the supplementary materials of the paper. The code here is a clean-room reimplementation for pedagogical purposes and does not reproduce any proprietary training data.
