TabKD: Data-Free Knowledge Distillation for Tabular Models via Interaction Diversity | AI Trend Blend

TabKD: What Happens When You Teach a Tiny Model to Think Like XGBoost — Without Seeing Any Real Data

Researchers at UT Arlington built a data-free knowledge distillation framework for tabular models that borrows a principle from software testing — systematically covering all pairs of feature interactions — and achieves the best student-teacher agreement in 14 of 16 benchmark configurations across neural networks, XGBoost, Random Forest, and TabTransformer teachers, outperforming five state-of-the-art baselines.

TabKD · Data-Free Knowledge Distillation · Tabular Data · Feature Interactions · Combinatorial Testing · Dynamic Bin Learning · XGBoost Compression · Model Extraction · Interaction Diversity

A hospital trains a cancer-risk model on tens of thousands of patient records. That model works brilliantly, but patient privacy laws mean the training data can never leave the hospital’s servers. The model itself is too large for the mobile diagnostic device the hospital wants to deploy. How do you compress the model when you cannot access the data it was trained on? This is the data-free knowledge distillation problem — and for tabular data, it has been quietly unsolved. TabKD solves it by asking a question borrowed from software engineering: instead of generating random synthetic samples and hoping they cover the important cases, what if we could guarantee that every meaningful combination of feature behaviors gets tested?


Why Tabular Data Is Genuinely Different

When researchers talk about knowledge distillation, they almost always mean image models. And the methods that work for images — adversarial generators, batch normalization inversion, spatial activation matching — rely on assumptions that simply do not hold for tabular data.

The first assumption is that the teacher is differentiable. You can backpropagate through a neural network to generate inputs that fool it. You cannot backpropagate through XGBoost or Random Forest. These are the most widely used models in tabular domains precisely because their ensemble structure makes them powerful — but it also makes them opaque to gradient-based generation.

The second assumption is that important features interact locally and hierarchically. Image models exploit spatial locality: a pixel matters primarily in relation to its neighbors. Tabular models work completely differently. A credit risk model doesn’t look at age and debt-to-income ratio as separate signals — it learns that the combination (age > 50 AND debt-to-income < 0.3) predicts low risk, while neither condition alone is sufficient. These are sharp, non-linear interactions across features that may have nothing structurally in common. There is no spatial neighborhood. There is no locality prior.

The third assumption is that random or entropy-guided generation will eventually explore the parts of feature space that matter. For tabular data with many features, this is overoptimistic. The combinatorial explosion of possible feature combinations means that random sampling reliably misses specific pairwise interactions, especially rare ones, leaving the student model permanently blind to the decision rules that govern them.

The Core Problem

Existing data-free distillation methods exhibit mode collapse on tabular data — the generator fixates on a small region of the feature space and never explores critical decision rules encoded in specific feature combinations. The student ends up mimicking the teacher’s average behavior but missing the precise interaction patterns that make the teacher accurate on hard cases.

The Insight From Software Testing

The breakthrough in TabKD comes from an unexpected direction: combinatorial testing, a methodology from software engineering developed to find bugs in complex systems efficiently.

The key observation behind combinatorial testing is this: software bugs are almost never triggered by all parameters being in unusual states simultaneously. In practice, nearly all known bugs are triggered by interactions among just two or three parameters. This means you do not need to test every possible combination of all parameters — you just need to guarantee that every pair of parameter values (2-way coverage) gets exercised at least once. This requires dramatically fewer test cases than exhaustive testing, while still catching the vast majority of real bugs.
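The arithmetic is easy to verify. Below is a minimal sketch with a hypothetical 4-parameter boolean system (not from the paper): exhaustive testing needs 2⁴ = 16 cases, but a well-chosen 5-case suite already exercises every pair of parameter values.

```python
from itertools import combinations, product

# Hypothetical system with 4 boolean parameters: exhaustive testing
# needs 2**4 = 16 cases, but these 5 cases cover every value pair.
suite = [
    (0, 0, 0, 0),
    (0, 1, 1, 1),
    (1, 0, 1, 1),
    (1, 1, 0, 1),
    (1, 1, 1, 0),
]

def pairwise_covered(cases, n_params=4, values=(0, 1)):
    """True iff every value pair (a, b) appears for every parameter pair (i, j)."""
    needed = {(i, j, a, b)
              for i, j in combinations(range(n_params), 2)
              for a, b in product(values, repeat=2)}
    seen = {(i, j, c[i], c[j])
            for c in cases
            for i, j in combinations(range(n_params), 2)}
    return needed <= seen

print(pairwise_covered(suite))  # True — 5 cases instead of 16
```

Dropping any single case breaks coverage, which is what "dramatically fewer but sufficient" means in practice.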

TabKD applies exactly this logic to synthetic data generation for tabular model distillation. The authors argue — and demonstrate empirically — that tabular model decisions similarly depend on interactions among small feature subsets. The credit risk rule (age > 50 AND debt-to-income < 0.3 → low risk) is a pairwise interaction. So is the medical rule (blood pressure > threshold AND cholesterol > threshold → at-risk). Covering all pairwise feature combinations in the synthetic training data is not just a nice-to-have — it is a sufficient condition for the student to observe all the decision patterns the teacher uses.

But to cover pairwise feature interactions, you first need to discretize each feature into a finite set of meaningful regions — bins. Random uniform bins waste coverage: if a feature takes values from 1 to 100 and the teacher’s decision boundary lies at 47, then a uniform bin covering [40, 60] lumps together values with completely different predictions. You need bins aligned with where the teacher’s behavior actually changes.
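A toy illustration of that failure mode, with a hypothetical 1-D teacher whose boundary sits at 47 (my example, not the paper's):

```python
import numpy as np

# Hypothetical 1-D teacher with a decision boundary at x = 47.
teacher = lambda x: (x > 47).astype(int)
x = np.linspace(1, 100, 1000)

# A uniform bin spanning [40, 60] mixes both predicted classes, so
# "one sample lands in this bin" tells the student almost nothing...
uniform_bin = (x >= 40) & (x <= 60)
print(np.unique(teacher(x[uniform_bin])))  # [0 1] — heterogeneous bin

# ...while a bin edge placed at the boundary keeps each bin pure.
aligned_bin = (x >= 40) & (x < 47)
print(np.unique(teacher(x[aligned_bin])))  # [0] — homogeneous bin
```

Boundary-aligned bins make "coverage of a bin" equivalent to "coverage of a distinct teacher behavior", which is exactly what Stage 1 below optimizes for.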

The TabKD Framework: Three Stages

TABKD FRAMEWORK — THREE-STAGE TRAINING
════════════════════════════════════════════════════════════════

STAGE 0 — WARMUP  (30 epochs)
  Random noise z ~ N(0, I)
         │
  Generator G(z) → synthetic samples x_gen
         │
  Teacher T(x_gen) → soft labels (ground truth)
         │
  Student S(x_gen) → minimize KL divergence
         │
  Populate replay buffer (later mixed in at a 10% ratio as stability anchors)
         │
  Goal: Give student a reasonable starting point before
        adversarial training begins.

────────────────────────────────────────────────────────────────

STAGE 1 — BIN LEARNING  (200 epochs, linear temp annealing)
  Soft Bin Membership m_k^(i)(x):
    Each feature i is partitioned into K=8 adaptive bins.
    Membership is SOFT (differentiable, not hard assignment).
         │
  Bin Loss (Eq. 1):
    L_bin = λ_intra · Var_intra(M, P_T)       ← minimize: same predictions inside bins
           + λ_inter · 1/Var_inter(M, P_T)     ← maximize: different predictions across bins
         │
  Boundary-Focused Generator Loss (Eq. 2):
    L_gen^(1) = λ_div · L_class-div(S(x_gen))     ← class balance
              + λ_boundary · L_entropy(T(x_gen))   ← target high-uncertainty regions
         │
  Goal: Bins stabilize at teacher decision boundaries.
        Bins are FROZEN before Stage 2 begins.
  ⚠ Staged approach critical: jointly optimizing bins and
    coverage causes instability (generator fills bins before
    boundaries stabilize).

────────────────────────────────────────────────────────────────

STAGE 2 — ADVERSARIAL DISTILLATION  (400 epochs)

  With frozen bins, the interaction space is now finite:
  F*(F-1)/2 feature pairs × K² bin combinations per pair.

  Empirical joint bin distribution (Eq. 3):
    P(k1, k2 | i, j) = (1/N) Σ_x  m_k1^(i)(x) · m_k2^(j)(x)

  Diversity Loss — maximize entropy = uniform coverage (Eq. 4):
    L_diversity = −(1/C(F,2)) Σ_{i<j} H(P(·, · | i, j))
         │
  Hardness Loss: maximize student-teacher disagreement
    L_hardness = −D_KL(S(x_gen) ‖ T(x_gen))
         │
  Student batch: 90% adversarial samples + 10% replay buffer
         │
  Goal: Uniform coverage of all pairwise bin cells while
        concentrating samples where the student still fails.

Stage 1: Dynamic Bin Learning

The bin learner assigns each input sample a soft membership vector over K bins for each feature. Soft membership is critical — hard discrete assignment would break differentiability and make the bin boundaries untrainable. The training objective pulls bin boundaries toward the teacher's decision boundaries by minimizing prediction variance within bins (samples in the same bin should get similar teacher predictions) and maximizing prediction variance between bins (samples in different bins should get different teacher predictions).

Eq. 1 — Bin Learning Loss L_bin = λ_intra · Var_intra(M, P_T) + λ_inter · (1 / Var_inter(M, P_T))

During bin learning, a separate boundary-focused generator helps by producing samples near the teacher's decision boundaries — exactly where bin assignments are most informative. Entropy maximization over teacher predictions targets high-uncertainty regions where predictions are approximately 0.5, which correspond to the boundaries the bins need to capture.
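That boundary-targeting entropy term can be sketched in a few lines (an assumed form consistent with the description, not the paper's exact code):

```python
import torch

def boundary_entropy_loss(teacher_probs: torch.Tensor) -> torch.Tensor:
    """Negative mean entropy of teacher outputs: minimizing this pushes
    the generator toward high-uncertainty (boundary) regions."""
    ent = -(teacher_probs * (teacher_probs + 1e-10).log()).sum(dim=-1)
    return -ent.mean()

# Near-uniform predictions (close to a boundary) score lower loss than
# confident ones, so the generator is rewarded for finding boundaries.
near_boundary = torch.tensor([[0.52, 0.48]])
confident = torch.tensor([[0.99, 0.01]])
print(boundary_entropy_loss(near_boundary) < boundary_entropy_loss(confident))  # tensor(True)
```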

Stage 2: Interaction Diversity Loss

With bins frozen, the interaction space becomes finite and enumerable. For F features and K bins each, there are F*(F-1)/2 unique feature pairs, and each pair has K² possible bin combinations. Full 2-way coverage means the generator must produce at least one sample in each of these K² cells for every feature pair.
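Plugging in the Adult benchmark's dimensions makes the scale concrete (a back-of-envelope count of mine, using F = 14 and K = 8 from the paper's setup):

```python
from math import comb

F_feats, K = 14, 8          # Adult benchmark: 14 features, K = 8 bins each
pairs = comb(F_feats, 2)    # F*(F-1)/2 = 91 unique feature pairs
cells = pairs * K * K       # 91 × 64 = 5,824 pairwise bin cells in total
min_samples = K * K         # each sample hits one cell per pair, so full
                            # coverage needs at least 64 well-placed samples
print(pairs, cells, min_samples)  # 91 5824 64
```

The lower bound of 64 samples is only achievable with carefully arranged samples; random generation needs vastly more, which is why coverage must be optimized rather than hoped for.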

The coverage metric is the empirical joint distribution over bin combinations for each feature pair. Maximum entropy over this distribution corresponds to uniform coverage. TabKD's diversity loss minimizes negative entropy — equivalently maximizes entropy — over all pairwise joint distributions:

Eq. 4 — Interaction Diversity Loss L_diversity = −(1 / C(F,2)) · Σ_{i < j} H(P(·, · | i, j))

This is elegant in its simplicity. The generator does not need to enumerate all possible bin combinations explicitly — it just needs to produce samples whose bin assignments form a uniform distribution across all feature pairs. The entropy loss naturally pushes the generator toward underexplored combinations.
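A quick numerical sanity check of the entropy-coverage equivalence (illustrative, not from the paper):

```python
import torch

K = 8
uniform = torch.full((K, K), 1.0 / (K * K))   # every bin cell equally covered
collapsed = torch.zeros(K, K)
collapsed[0, 0] = 1.0                         # mode collapse: one cell only

def joint_entropy(p: torch.Tensor) -> float:
    p = p.reshape(-1) + 1e-10                 # numerical stability
    return float(-(p * p.log()).sum())

print(round(joint_entropy(uniform), 3))    # 4.159 = log(64), the maximum
print(round(joint_entropy(collapsed), 3))  # 0.0 — no diversity signal left
```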

Hardness + Stability

Diversity alone is not enough. A generator that covers all bin combinations uniformly but avoids regions where the student currently disagrees with the teacher is wasteful — it generates samples the student already handles correctly. The hardness loss adds a complementary objective: reward the generator for finding samples where the student fails. KL divergence between teacher and student outputs (negated, so the generator maximizes it) naturally concentrates samples near decision boundaries where student errors are most costly.

The student, meanwhile, learns from a mix of 90% adversarial samples (where it is most likely to be wrong) and 10% replay buffer samples collected during warmup. The replay buffer prevents catastrophic forgetting — without it, the student would optimize perfectly for edge cases while degrading on the ordinary inputs it learned during warmup.
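Both mechanisms can be sketched as follows (assumed forms consistent with the description; `hardness_loss` and `mix_student_batch` are my names, not the paper's):

```python
import torch
import torch.nn.functional as F

def hardness_loss(student_log_probs: torch.Tensor,
                  teacher_probs: torch.Tensor) -> torch.Tensor:
    """Negated teacher-student KL: the generator MINIMIZES this, i.e.
    maximizes disagreement, concentrating samples where the student fails."""
    return -F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")

def mix_student_batch(x_adv: torch.Tensor, replay: torch.Tensor,
                      batch_size: int = 128,
                      replay_ratio: float = 0.1) -> torch.Tensor:
    """90% fresh adversarial samples + 10% replay anchors against forgetting."""
    n_replay = int(batch_size * replay_ratio)
    idx = torch.randperm(replay.shape[0])[:n_replay]
    return torch.cat([x_adv[: batch_size - n_replay], replay[idx]], dim=0)
```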

Design Rationale

The three-stage training is not just an engineering convenience — it prevents a specific failure mode. If you jointly optimize bins and coverage from the start, the generator fills bins faster than bin boundaries can stabilize. By the time the generator is producing diverse samples, the bins are misaligned and "diverse" means nothing. Freezing bins before adversarial training begins ensures that coverage statistics reflect genuine semantic diversity in the teacher's prediction landscape.


Experimental Setup and Benchmarks

TabKD is evaluated across four binary classification datasets: Adult Income (48K samples, 14 features — predicting whether income exceeds $50K), Credit Card Default (30K samples, 23 features — predicting payment default), Breast Cancer (569 samples, 30 features — malignant vs. benign), and Mushroom (8K samples, 22 features — edible vs. poisonous).

Four teacher architectures are evaluated: a neural network (two hidden layers, 128→64 units), XGBoost (100 estimators, depth 6), Random Forest (100 trees, depth 10), and TabTransformer (multi-head self-attention over tabular features). This gives 16 dataset-teacher combinations. All teachers use a query budget of 9,600 synthetic samples, following TabExtractor's evaluation protocol.

The student is always a single-hidden-layer network with 32 units — a deliberately lightweight model that makes the distillation challenge harder. The same student architecture is used regardless of the teacher's complexity, making the results clean and fair.

Results: 14 of 16 Wins

Headline Numbers: Breast Cancer and Mushroom

Method          Breast Cancer NN   Breast Cancer XGB   Mushroom NN   Mushroom RF
StealML                     75.4                88.6          94.0          85.5
TabExtractor                85.1                86.1          90.5          79.0
CF                          86.0                79.2          63.0          65.5
DualCF                      83.3                72.8          90.0          88.5
DivT                        87.7                89.5          93.0          88.0
TabKD (Ours)                95.6                96.5          96.5          91.1

Selected agreement accuracy (%) results. TabKD's margins on Breast Cancer are particularly large: up to 20 percentage points over the weakest baseline (StealML) and 7.9 over the strongest (DivT) with the Neural Network teacher.

The Full Picture: All Datasets and Teachers

Dataset         Teacher          Best Baseline   TabKD   Δ
Adult           NN               91.4 (DivT)     91.0    −0.4 (DivT wins)
Adult           XGBoost          78.6            81.1    +2.5
Adult           RF               69.1            84.4    +15.3
Adult           TabTransformer   85.0            86.1    +1.1
Credit          NN               90.9            97.0    +6.1
Credit          XGBoost          85.8            88.0    +2.2
Credit          RF               87.5            87.7    +0.2
Credit          TabTransformer   81.4            87.1    +5.7
Breast Cancer   NN               87.7            95.6    +7.9
Breast Cancer   XGBoost          89.5            96.5    +7.0
Breast Cancer   RF               83.5            91.2    +7.7
Breast Cancer   TabTransformer   89.1            90.4    +1.3
Mushroom        NN               94.0            96.5    +2.5
Mushroom        XGBoost          81.5            83.5    +2.0
Mushroom        RF               88.5            91.1    +2.6
Mushroom        TabTransformer   88.3 (DivT)     85.2    −3.1 (DivT wins)

Table: TabKD vs. best baseline agreement accuracy. TabKD leads in 14/16 configurations. The two exceptions are Adult+NN (marginal −0.4%) and Mushroom+TabTransformer (−3.1%, where DivT leads).

The two cases where TabKD does not lead are instructive. Adult with Neural Network teacher: DivT achieves 91.4% vs. TabKD's 91.0% — a difference of 0.4 percentage points, well within noise. Mushroom with TabTransformer: DivT outperforms at 88.3% vs. 85.2%. The authors note that Mushroom's features are largely categorical with clear rules, where even uniform static bins work well. The adaptive bin learning provides less marginal benefit when the feature space is already naturally discrete.

"Interaction coverage strongly correlates with distillation quality, validating our core hypothesis that systematic coverage of feature combinations is essential for effective tabular distillation." — Pereira, Khadka, and Lei, arXiv:2603.15481 (2026)

The Coverage-Agreement Correlation

The most important result in the paper is not a single accuracy number — it is the relationship between interaction coverage and student-teacher agreement, shown across all four teacher types. As coverage increases from ~20% to ~80% of all pairwise bin combinations, agreement accuracy rises monotonically for every teacher architecture. The relationship is consistent and strong: more coverage means better distillation, regardless of what kind of model you are trying to distill.

This validates TabKD's core premise. It is not merely a clever trick for getting slightly better numbers — it is a demonstration that the reason prior methods fail is their inability to cover the feature interaction space, and that fixing this problem directly improves distillation quality in a measurable, predictable way.

Ablation: Why Dynamic Bins Matter

Teacher          Adult (Dynamic)   Adult (Static)   Credit (Dynamic)   Credit (Static)
Neural Network              92.0             87.0               97.0              91.0
XGBoost                     81.1             42.0               88.5              88.0
Random Forest               84.4             58.9               88.2              88.2
TabTransformer              86.1             33.0               87.1              86.6

Table 3: Dynamic vs. static binning ablation. The Adult dataset shows catastrophic degradation with static bins — XGBoost drops from 81.1% to 42.0%, TabTransformer from 86.1% to 33.0%.

The ablation study on static bins delivers the clearest demonstration of why adaptive bin learning matters. On the Adult dataset — which is complex, imbalanced, and has continuous features where decision boundaries are not obvious — replacing dynamic bins with uniform equal-width bins is catastrophic. XGBoost agreement drops from 81.1% to 42.0%. TabTransformer drops from 86.1% to 33.0%. Without bins aligned to teacher decision boundaries, the coverage metric is measuring the wrong thing: "uniform coverage" across meaningless uniform bins guarantees nothing about meaningful feature interaction coverage.

On simpler datasets like Mushroom, where many features are categorical and their natural values already act as meaningful bins, dynamic and static bins perform similarly. This makes intuitive sense and gives confidence that the dynamic bin learning is doing real work rather than overfitting to a quirk of the evaluation.

Practical Takeaway

For practitioners considering TabKD: the dynamic bin learning stage is not optional on continuous-feature datasets. Static bins can cause catastrophic failure. The 200-epoch bin stabilization phase is the most critical investment in the training pipeline — once bins are well-aligned, the adversarial distillation phase reliably produces strong results across all teacher types including non-differentiable gradient-free ensembles like XGBoost and Random Forest.

Complete End-to-End TabKD Implementation (PyTorch)

The implementation below is a complete, runnable PyTorch implementation of TabKD, structured across 10 sections that map directly to the paper. It covers dynamic soft bin learning with boundary-focused sampling, the interaction diversity loss with entropy maximization over pairwise joint bin distributions, the hardness loss via KL divergence, the three-phase training loop (warmup → bin stabilization → adversarial distillation), replay buffer management, all four teacher wrappers (Neural Network, XGBoost, Random Forest, TabTransformer), dataset helpers for all four benchmark datasets, evaluation metrics including coverage quantification, and a smoke test that validates all components without requiring real data.

# ==============================================================================
# TabKD: Tabular Knowledge Distillation through Interaction Diversity
#        of Learned Feature Bins
# Paper: arXiv:2603.15481v1 [cs.LG] (2026)
# Authors: Shovon Niverd Pereira, Krishna Khadka, Yu Lei
# Affiliation: University of Texas at Arlington
# ==============================================================================
# Sections:
#   1.  Imports & Configuration
#   2.  Dynamic Soft Bin Learner (decision-boundary-aligned discretization)
#   3.  Interaction Diversity Loss (pairwise bin entropy maximization)
#   4.  Generator Network (noise → synthetic tabular samples)
#   5.  Student Network (lightweight MLP approximation)
#   6.  Teacher Wrappers (NN, XGBoost, Random Forest, TabTransformer)
#   7.  Replay Buffer (catastrophic forgetting prevention)
#   8.  TabKD Trainer (three-phase: warmup → bins → adversarial)
#   9.  Dataset Helpers & Evaluation Metrics
#  10.  Smoke Test
# ==============================================================================

from __future__ import annotations

import math
import random
import warnings
from collections import deque
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset

warnings.filterwarnings("ignore")


# ─── SECTION 1: Configuration ─────────────────────────────────────────────────

@dataclass
class TabKDConfig:
    """
    Hyperparameter configuration for TabKD (Section 5.7).

    Attributes
    ----------
    n_features      : int   — number of input features (F)
    n_classes       : int   — number of output classes (C)
    n_bins          : int   — bins per feature K (paper: 8)
    noise_dim       : int   — generator input noise dimension
    gen_hidden      : int   — generator hidden layer size
    student_hidden  : int   — student MLP hidden units (paper: 32)
    batch_size      : int   — training batch size (paper: 128)
    lr              : float — Adam learning rate (paper: 0.001)
    lambda_cov      : float — diversity loss weight (paper: 10.0)
    lambda_hard     : float — hardness loss weight (paper: 2.0)
    lambda_intra    : float — intra-bin variance weight
    lambda_inter    : float — inter-bin inverse-variance weight
    lambda_div      : float — class diversity weight (bin phase generator)
    lambda_boundary : float — boundary entropy weight (bin phase generator)
    warmup_epochs   : int   — Stage 0 warmup epochs (paper: 30)
    bin_epochs      : int   — Stage 1 bin learning epochs (paper: 200)
    adv_epochs      : int   — Stage 2 adversarial epochs (paper: 400)
    tau_start       : float — temperature annealing start
    tau_end         : float — temperature annealing end
    replay_ratio    : float — fraction of replay samples in adversarial batch (paper: 0.1)
    replay_size     : int   — maximum replay buffer size
    query_budget    : int   — total synthetic query budget (paper: 9600)
    """
    n_features: int = 14
    n_classes: int = 2
    n_bins: int = 8
    noise_dim: int = 64
    gen_hidden: int = 128
    student_hidden: int = 32
    batch_size: int = 128
    lr: float = 1e-3
    lambda_cov: float = 10.0
    lambda_hard: float = 2.0
    lambda_intra: float = 1.0
    lambda_inter: float = 1.0
    lambda_div: float = 1.0
    lambda_boundary: float = 1.0
    warmup_epochs: int = 30
    bin_epochs: int = 200
    adv_epochs: int = 400
    tau_start: float = 1.0
    tau_end: float = 0.05
    replay_ratio: float = 0.1
    replay_size: int = 2000
    query_budget: int = 9600


# ─── SECTION 2: Dynamic Soft Bin Learner ──────────────────────────────────────

class SoftBinLearner(nn.Module):
    """
    Learns K adaptive bin boundaries per feature, aligned with teacher
    decision boundaries (Section 4.2).

    Instead of hard discrete assignments (which are non-differentiable),
    uses a temperature-annealed softmax over learned bin center offsets
    to produce soft membership vectors m ∈ [0,1]^K for each feature.

    Bin Learning Objective (Eq. 1):
        L_bin = λ_intra · Var_intra(M, P_T)
              + λ_inter · (1 / Var_inter(M, P_T))

    Intra-bin variance: for each bin, the variance of teacher predictions
    weighted by soft memberships. Low = homogeneous bins (desired).
    Inter-bin variance: variance across bin-mean predictions.
    High = bins capture different behaviors (desired).

    Parameters
    ----------
    n_features : int — number of input features
    n_bins     : int — number of bins per feature K
    tau        : float — initial temperature for soft assignment
    """

    def __init__(self, n_features: int, n_bins: int, tau: float = 1.0):
        super().__init__()
        self.n_features = n_features
        self.n_bins = n_bins
        self.tau = tau

        # Learnable bin centers: (F, K)
        # Initialized to spread uniformly in [-2, 2] (standardized feature space)
        centers_init = torch.linspace(-2.0, 2.0, n_bins).unsqueeze(0).expand(n_features, -1).clone()
        self.bin_centers = nn.Parameter(centers_init)

    def set_temperature(self, tau: float):
        """Update soft assignment temperature (called during annealing)."""
        self.tau = tau

    def forward(self, x: Tensor) -> Tensor:
        """
        Compute soft bin membership for each sample.

        Parameters
        ----------
        x : (N, F) — input samples (standardized)

        Returns
        -------
        m : (N, F, K) — soft membership vectors
            m[n, f, k] = probability that sample n, feature f belongs to bin k
        """
        N = x.shape[0]
        # x: (N, F) → (N, F, 1)
        x_expanded = x.unsqueeze(-1)
        # bin_centers: (F, K) → (1, F, K)
        centers = self.bin_centers.unsqueeze(0)
        # Negative squared distance to each bin center: (N, F, K)
        neg_dist = -((x_expanded - centers) ** 2) / (self.tau + 1e-8)
        # Soft membership via temperature-scaled softmax over bins
        m = F.softmax(neg_dist, dim=-1)  # (N, F, K)
        return m

    def hard_assignment(self, x: Tensor) -> Tensor:
        """
        Hard bin assignment for coverage measurement at evaluation time.
        Returns (N, F) integer bin indices.
        """
        with torch.no_grad():
            m = self.forward(x)
            return m.argmax(dim=-1)  # (N, F)


class BinLearningLoss(nn.Module):
    """
    Computes the bin learning objective (Eq. 1 in paper):
        L_bin = λ_intra · Var_intra(M, P_T) + λ_inter · (1 / Var_inter(M, P_T))

    Intra-bin variance:
        For each feature f and bin k, compute the variance of teacher
        predictions P_T weighted by soft memberships m_k:
            var_intra_fk = Σ_n m_nfk · (P_T_n − μ_fk)²
            where μ_fk = Σ_n m_nfk · P_T_n / Σ_n m_nfk
        Average over all features and bins.

    Inter-bin variance:
        Variance of bin-mean predictions across bins (per feature).
        We want bins to have different mean predictions → maximize this.
    """

    def __init__(self, config: TabKDConfig):
        super().__init__()
        self.config = config

    def forward(self, m: Tensor, teacher_probs: Tensor) -> Tensor:
        """
        Parameters
        ----------
        m            : (N, F, K) — soft bin memberships
        teacher_probs: (N, C) — teacher output probabilities

        Returns
        -------
        loss : scalar — bin learning loss
        """
        N = m.shape[0]  # avoid unpacking into F/K: "F" would shadow torch.nn.functional
        # Use the max teacher probability as the scalar prediction signal
        p = teacher_probs.max(dim=-1).values  # (N,)

        # Expand p for broadcasting: (N, 1, 1)
        p_exp = p.unsqueeze(-1).unsqueeze(-1)

        # Weighted sum of memberships: (F, K)
        m_sum = m.sum(dim=0) + 1e-8  # (F, K)

        # Bin means μ_fk = Σ_n m_nfk · p_n / Σ_n m_nfk : (F, K)
        # m: (N, F, K), p: (N,)
        bin_means = (m * p.view(N, 1, 1)).sum(dim=0) / m_sum  # (F, K)

        # Intra-bin variance: (N, F, K) → scalar
        diff_sq = (p.view(N, 1, 1) - bin_means.unsqueeze(0)) ** 2
        intra_var = (m * diff_sq).sum(dim=0) / m_sum  # (F, K)
        loss_intra = intra_var.mean()

        # Inter-bin variance: variance of bin_means across bins, per feature
        inter_var = bin_means.var(dim=-1) + 1e-8  # (F,)
        loss_inter = (1.0 / inter_var).mean()

        loss = self.config.lambda_intra * loss_intra + self.config.lambda_inter * loss_inter
        return loss


# ─── SECTION 3: Interaction Diversity Loss ────────────────────────────────────

class InteractionDiversityLoss(nn.Module):
    """
    Core TabKD contribution: maximizes entropy over pairwise bin-combination
    distributions to achieve systematic feature interaction coverage (Eq. 3–4).

    For each pair of features (i, j), computes the empirical joint distribution
    P(k1, k2 | i, j) over K² bin combinations (Eq. 3):
        P(k1, k2 | i, j) = (1/N) Σ_x m_k1^(i)(x) · m_k2^(j)(x)

    Diversity loss = negative average entropy over all feature pairs (Eq. 4):
        L_diversity = −(1/C(F,2)) Σ_{i<j} H(P(·, · | i, j))
    """

    def __init__(self, config: TabKDConfig):
        super().__init__()
        self.config = config
        F = config.n_features
        self.n_pairs = F * (F - 1) // 2
        # Precompute all unique feature pair indices
        self.pairs = [(i, j) for i in range(F) for j in range(i + 1, F)]

    def forward(self, m: Tensor) -> Tensor:
        """
        Parameters
        ----------
        m : (N, F, K) — soft bin memberships from the bin learner

        Returns
        -------
        loss : scalar — negative average entropy (minimize = maximize diversity)
        """
        total_entropy = 0.0
        count = 0

        for i, j in self.pairs:
            m_i = m[:, i, :]  # (N, K)
            m_j = m[:, j, :]  # (N, K)

            # Joint distribution over K² bin combinations (Eq. 3)
            # outer product per sample: (N, K, K)
            joint = m_i.unsqueeze(-1) * m_j.unsqueeze(-2)  # (N, K, K)

            # Average over batch: empirical joint distribution (K, K)
            p_joint = joint.mean(dim=0)  # (K, K)

            # Flatten to (K²,) and compute entropy
            p_flat = p_joint.reshape(-1) + 1e-10  # numerical stability
            p_flat = p_flat / p_flat.sum()         # normalize to valid distribution
            entropy = -(p_flat * p_flat.log()).sum()

            total_entropy = total_entropy + entropy
            count += 1

        # Negative entropy: minimizing this maximizes entropy (uniform coverage)
        loss = -(total_entropy / max(1, count))
        return loss

    @torch.no_grad()
    def compute_coverage(self, bin_assignments: Tensor) -> float:
        """
        Compute cumulative interaction coverage (Eq. 13) given hard bin assignments.

        Parameters
        ----------
        bin_assignments : (N, F) — integer bin indices (from SoftBinLearner.hard_assignment)

        Returns
        -------
        coverage : float in [0, 1] — fraction of pairwise bin cells visited
        """
        K = self.config.n_bins  # feature pairs were precomputed in __init__ (self.pairs)
        visited = set()

        N = bin_assignments.shape[0]
        for n in range(N):
            row = bin_assignments[n]
            for i, j in self.pairs:
                k1 = row[i].item()
                k2 = row[j].item()
                visited.add((i, j, k1, k2))

        total_cells = len(self.pairs) * (K * K)
        return len(visited) / max(1, total_cells)


# ─── SECTION 4: Generator Network ─────────────────────────────────────────────

class TabularGenerator(nn.Module):
    """
    Generator G(z): maps noise z ~ N(0, I) to synthetic tabular samples.

    Architecture: 3-layer MLP with Batch Normalization and LeakyReLU activations.
    Output: F-dimensional synthetic samples in the standardized feature space.

    The generator is trained to simultaneously:
      1. Maximize interaction diversity (L_diversity) — cover all bin pairs
      2. Maximize student-teacher disagreement (L_hardness) — target weaknesses

    Parameters
    ----------
    config : TabKDConfig
    """

    def __init__(self, config: TabKDConfig):
        super().__init__()
        self.config = config
        H = config.gen_hidden
        F = config.n_features

        self.net = nn.Sequential(
            nn.Linear(config.noise_dim, H),
            nn.BatchNorm1d(H),
            nn.LeakyReLU(0.2),
            nn.Linear(H, H * 2),
            nn.BatchNorm1d(H * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(H * 2, H),
            nn.BatchNorm1d(H),
            nn.LeakyReLU(0.2),
            nn.Linear(H, F),
            nn.Tanh(),  # Output in [-1, 1] ≈ standardized feature range
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, z: Tensor) -> Tensor:
        """
        Parameters
        ----------
        z : (N, noise_dim) — Gaussian noise

        Returns
        -------
        x_gen : (N, F) — synthetic tabular samples
        """
        return self.net(z)

    def sample(self, n: int, device: torch.device) -> Tensor:
        """Generate n synthetic samples."""
        z = torch.randn(n, self.config.noise_dim, device=device)
        return self.forward(z)


# ─── SECTION 5: Student Network ───────────────────────────────────────────────

class TabularStudent(nn.Module):
    """
    Student S(x): single-hidden-layer MLP (32 units, ReLU) as per Section 5.3.

    Deliberately lightweight — the same architecture is used regardless of
    teacher type, making results fair and comparable across configurations.

    Training objective (Eq. 7):
        L_student = (1/N) Σ_x D_KL(S(x) ‖ T(x))
    where T(x) provides soft label targets from the teacher.

    Parameters
    ----------
    config : TabKDConfig
    """

    def __init__(self, config: TabKDConfig):
        super().__init__()
        self.config = config
        self.net = nn.Sequential(
            nn.Linear(config.n_features, config.student_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(config.student_hidden, config.n_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Returns log-probabilities via log_softmax."""
        return F.log_softmax(self.net(x), dim=-1)

    def predict(self, x: Tensor) -> Tensor:
        """Returns class predictions."""
        with torch.no_grad():
            return self.forward(x).argmax(dim=-1)


# ─── SECTION 6: Teacher Wrappers ──────────────────────────────────────────────

class NeuralNetworkTeacher(nn.Module):
    """
    Pre-trained neural network teacher (128 → 64 units, ReLU, dropout 0.2).
    As per Section 5.3.
    """

    def __init__(self, n_features: int, n_classes: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, n_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Returns class probabilities (softmax)."""
        logits = self.net(x)
        return F.softmax(logits, dim=-1)

    def predict(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            return self.forward(x).argmax(dim=-1)

    def train_on_data(
        self,
        X_train: Tensor,
        y_train: Tensor,
        epochs: int = 30,
        lr: float = 0.001,
    ) -> float:
        """Quick teacher pre-training. Returns final loss."""
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        dataset = TensorDataset(X_train, y_train)
        loader = DataLoader(dataset, batch_size=64, shuffle=True)
        final_loss = 0.0
        for epoch in range(epochs):
            for xb, yb in loader:
                optimizer.zero_grad()
                loss = criterion(self.net(xb), yb)
                loss.backward()
                optimizer.step()
                final_loss = loss.item()
        return final_loss


class XGBoostTeacher:
    """
    XGBoost teacher wrapper: a non-differentiable tree ensemble.
    TabKD handles it via black-box querying (no backprop through the teacher).

    Requires: pip install xgboost
    Falls back to mock implementation for smoke testing.
    """

    def __init__(self, n_estimators: int = 100, max_depth: int = 6, lr: float = 0.1):
        self.params = {"n_estimators": n_estimators, "max_depth": max_depth, "learning_rate": lr}
        self._model = None
        self._mock = False

    def fit(self, X: np.ndarray, y: np.ndarray):
        try:
            from xgboost import XGBClassifier
            # Note: use_label_encoder was removed in xgboost >= 2.0; omit it for compatibility
            self._model = XGBClassifier(**self.params, eval_metric="logloss")
            self._model.fit(X, y)
        except ImportError:
            print("  [XGBoostTeacher] xgboost not installed. Using mock teacher.")
            self._mock = True

    def predict_proba_tensor(self, x: Tensor) -> Tensor:
        """Returns teacher probabilities as a Tensor (N, C)."""
        if self._mock or self._model is None:
            N = x.shape[0]
            probs = torch.rand(N, 2, device=x.device)  # mock: random normalized probs on x's device
            return probs / probs.sum(dim=-1, keepdim=True)
        X_np = x.detach().cpu().numpy()
        probs = self._model.predict_proba(X_np)
        return torch.tensor(probs, dtype=torch.float32, device=x.device)

    def predict(self, x: Tensor) -> Tensor:
        return self.predict_proba_tensor(x).argmax(dim=-1)


class RandomForestTeacher:
    """
    Random Forest teacher wrapper. Non-differentiable ensemble.

    Requires: pip install scikit-learn
    Falls back to mock for smoke testing.
    """

    def __init__(self, n_estimators: int = 100, max_depth: int = 10, min_samples_split: int = 5):
        self.params = {"n_estimators": n_estimators, "max_depth": max_depth,
                       "min_samples_split": min_samples_split}
        self._model = None
        self._mock = False

    def fit(self, X: np.ndarray, y: np.ndarray):
        try:
            from sklearn.ensemble import RandomForestClassifier
            self._model = RandomForestClassifier(**self.params, random_state=42)
            self._model.fit(X, y)
        except ImportError:
            print("  [RandomForestTeacher] sklearn not installed. Using mock teacher.")
            self._mock = True

    def predict_proba_tensor(self, x: Tensor) -> Tensor:
        if self._mock or self._model is None:
            N = x.shape[0]
            probs = torch.rand(N, 2, device=x.device)  # mock: random normalized probs on x's device
            return probs / probs.sum(dim=-1, keepdim=True)
        X_np = x.detach().cpu().numpy()
        probs = self._model.predict_proba(X_np)
        return torch.tensor(probs, dtype=torch.float32, device=x.device)

    def predict(self, x: Tensor) -> Tensor:
        return self.predict_proba_tensor(x).argmax(dim=-1)


class TabTransformerTeacher(nn.Module):
    """
    Lightweight TabTransformer teacher (Section 5.3).
    Multi-head self-attention over feature embeddings.

    Production: use the full TabTransformer implementation:
        pip install tab-transformer-pytorch
    """

    def __init__(self, n_features: int, n_classes: int, d_model: int = 32, n_heads: int = 4, n_layers: int = 2):
        super().__init__()
        self.feature_embed = nn.Linear(1, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True, dropout=0.1)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(d_model * n_features, n_classes)
        self.n_features = n_features

    def forward(self, x: Tensor) -> Tensor:
        # x: (N, F) → (N, F, 1) → embed → (N, F, d_model)
        x_emb = self.feature_embed(x.unsqueeze(-1))
        x_out = self.transformer(x_emb)
        x_flat = x_out.reshape(x.shape[0], -1)
        return F.softmax(self.head(x_flat), dim=-1)

    def predict(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            return self.forward(x).argmax(dim=-1)

    def train_on_data(self, X_train: Tensor, y_train: Tensor, epochs: int = 20) -> float:
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        dataset = TensorDataset(X_train, y_train)
        loader = DataLoader(dataset, batch_size=64, shuffle=True)
        final_loss = 0.0
        for _ in range(epochs):
            for xb, yb in loader:
                optimizer.zero_grad()
                logits = self.head(
                    self.transformer(self.feature_embed(xb.unsqueeze(-1))).reshape(xb.shape[0], -1)
                )
                loss = F.cross_entropy(logits, yb)
                loss.backward()
                optimizer.step()
                final_loss = loss.item()
        return final_loss


def get_teacher_probs(teacher, x: Tensor) -> Tensor:
    """Unified interface to get teacher soft label probabilities."""
    if isinstance(teacher, (NeuralNetworkTeacher, TabTransformerTeacher)):
        with torch.no_grad():
            return teacher(x)
    elif isinstance(teacher, (XGBoostTeacher, RandomForestTeacher)):
        return teacher.predict_proba_tensor(x)
    raise ValueError(f"Unknown teacher type: {type(teacher)}")


# ─── SECTION 7: Replay Buffer ─────────────────────────────────────────────────

class ReplayBuffer:
    """
    Experience replay buffer populated during warmup (Section 4.5).

    Prevents catastrophic forgetting: the student is trained on a mix of
      - 90% adversarial samples (hard, diverse, from adversarial training)
      - 10% replay samples  (typical, from warmup — stability anchors)

    This ensures the student does not overfit to edge cases during adversarial
    training while forgetting typical inputs.
    """

    def __init__(self, max_size: int = 2000):
        self.buffer = deque(maxlen=max_size)
        self.max_size = max_size

    def add(self, x: Tensor):
        """Add samples to the buffer."""
        for i in range(x.shape[0]):
            self.buffer.append(x[i].detach().cpu())

    def sample(self, n: int, device: torch.device) -> Optional[Tensor]:
        """Sample n items from the buffer. Returns None if buffer is empty."""
        if len(self.buffer) == 0:
            return None
        n = min(n, len(self.buffer))
        samples = random.sample(list(self.buffer), n)
        return torch.stack(samples, dim=0).to(device)

    def __len__(self):
        return len(self.buffer)


# ─── SECTION 8: TabKD Trainer ─────────────────────────────────────────────────

class TabKDTrainer:
    """
    Full TabKD training pipeline implementing all three phases (Section 5.5):

    Phase 0 — Warmup (30 epochs):
        Train student on random samples, populate replay buffer.
        Objective: Eq. 8 — standard KL divergence distillation.

    Phase 1 — Bin Stabilization (200 epochs, temperature annealing):
        Train bin learner to align with teacher decision boundaries.
        Run boundary-focused generator to supply near-boundary samples.
        Objective: Eq. 1 (bin loss) + Eq. 2 (boundary generator).
        Freeze bin learner upon completion.

    Phase 2 — Adversarial Distillation (400 epochs):
        Generator maximizes interaction coverage + student hardness.
        Student minimizes KL divergence from teacher.
        Mix: 90% adversarial + 10% replay buffer samples.
        Objective: Eq. 6 (generator) + Eq. 7 (student).
    """

    def __init__(self, config: TabKDConfig, teacher, device: torch.device):
        self.config = config
        self.teacher = teacher
        self.device = device

        self.bin_learner = SoftBinLearner(config.n_features, config.n_bins).to(device)
        self.generator = TabularGenerator(config).to(device)
        self.student = TabularStudent(config).to(device)
        self.replay_buffer = ReplayBuffer(config.replay_size)

        self.bin_loss_fn = BinLearningLoss(config)
        self.diversity_loss_fn = InteractionDiversityLoss(config)

        self.gen_optimizer = torch.optim.Adam(self.generator.parameters(), lr=config.lr)
        self.student_optimizer = torch.optim.Adam(self.student.parameters(), lr=config.lr)
        self.bin_optimizer = torch.optim.Adam(self.bin_learner.parameters(), lr=config.lr * 0.5)

        # Cosine annealing schedulers
        total_epochs = config.warmup_epochs + config.bin_epochs + config.adv_epochs
        self.gen_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.gen_optimizer, T_max=total_epochs
        )
        self.student_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.student_optimizer, T_max=total_epochs
        )

        self.history = []

    def _generate_batch(self, n: int) -> Tensor:
        """Generate a batch of synthetic samples."""
        return self.generator.sample(n, self.device)

    def _teacher_probs(self, x: Tensor) -> Tensor:
        """Get teacher soft labels for a batch of samples."""
        return get_teacher_probs(self.teacher, x)

    def _annealed_temperature(self, epoch: int) -> float:
        """Linear temperature annealing during bin learning (Eq. 9)."""
        tau = self.config.tau_start - (
            (self.config.tau_start - self.config.tau_end) * epoch / max(1, self.config.bin_epochs)
        )
        return max(self.config.tau_end, tau)

    def _warmup(self):
        """
        Phase 0: Pre-train student on random samples.
        Populates replay buffer for later catastrophic forgetting prevention.
        (Eq. 8: L_warmup = KL(S(x) ‖ T(x)))
        """
        print(f"  [Phase 0] Warmup: {self.config.warmup_epochs} epochs...")
        self.student.train()
        total_loss = 0.0

        for epoch in range(self.config.warmup_epochs):
            x_rand = self._generate_batch(self.config.batch_size)
            teacher_probs = self._teacher_probs(x_rand)

            # Distillation loss. Note F.kl_div(input, target) expects log-probs
            # as input and computes KL(target ‖ input), i.e. KL(T(x) ‖ S(x)) —
            # the standard distillation direction.
            log_student = self.student(x_rand)
            loss = F.kl_div(log_student, teacher_probs, reduction="batchmean")

            self.student_optimizer.zero_grad()
            loss.backward()
            self.student_optimizer.step()
            total_loss += loss.item()

            # Populate replay buffer with diverse warmup samples
            self.replay_buffer.add(x_rand.detach())

        avg_loss = total_loss / self.config.warmup_epochs
        print(f"    Warmup complete | avg KL loss: {avg_loss:.4f} | replay buffer: {len(self.replay_buffer)} samples")

    def _bin_stabilization(self):
        """
        Phase 1: Learn adaptive bin boundaries aligned with teacher decisions.
        Temperature anneals from tau_start to tau_end across bin_epochs.
        Bins are frozen after this phase completes.
        (Eq. 1: L_bin; Eq. 2: L_gen^(1))
        """
        print(f"  [Phase 1] Bin Stabilization: {self.config.bin_epochs} epochs...")
        self.bin_learner.train()
        self.generator.train()
        total_bin_loss = 0.0

        for epoch in range(self.config.bin_epochs):
            # Update temperature for annealing
            tau = self._annealed_temperature(epoch)
            self.bin_learner.set_temperature(tau)

            # ── Generator: boundary-focused sampling (Eq. 2) ────────────────
            x_gen = self._generate_batch(self.config.batch_size)
            teacher_probs = self._teacher_probs(x_gen)
            student_log_probs = self.student(x_gen)

            # L_class-div: negative entropy of the class marginals — minimizing
            # it maximizes marginal entropy, encouraging class-balanced samples
            student_probs = student_log_probs.exp()
            class_mean = student_probs.mean(dim=0)
            l_class_div = (class_mean * class_mean.log().clamp(min=-10)).sum()

            # L_entropy: high teacher entropy ⇒ sample lies near a decision
            # boundary. Minimizing -entropy maximizes teacher entropy,
            # steering the generator toward boundary-adjacent samples.
            t_entropy = -(teacher_probs * (teacher_probs + 1e-10).log()).sum(dim=-1).mean()
            l_entropy = -t_entropy

            gen_loss_phase1 = (self.config.lambda_div * l_class_div
                               + self.config.lambda_boundary * l_entropy)

            self.gen_optimizer.zero_grad()
            gen_loss_phase1.backward()  # graph not reused: the bin learner below uses detached x_gen
            self.gen_optimizer.step()

            # ── Bin Learner: align bins with teacher decision boundaries (Eq. 1)
            x_gen_detach = x_gen.detach()
            teacher_probs_detach = self._teacher_probs(x_gen_detach)
            m = self.bin_learner(x_gen_detach)
            bin_loss = self.bin_loss_fn(m, teacher_probs_detach)

            self.bin_optimizer.zero_grad()
            bin_loss.backward()
            self.bin_optimizer.step()
            total_bin_loss += bin_loss.item()

            if (epoch + 1) % 50 == 0:
                print(f"    Bin epoch {epoch+1}/{self.config.bin_epochs} | bin_loss={bin_loss.item():.4f} | τ={tau:.3f}")

        # Freeze bin learner
        for p in self.bin_learner.parameters():
            p.requires_grad = False
        self.bin_learner.eval()
        print(f"    Bins frozen. avg_bin_loss={total_bin_loss/self.config.bin_epochs:.4f}")

    def _adversarial_distillation(self):
        """
        Phase 2: Adversarial generator + student training with frozen bins.
        Generator maximizes interaction coverage + hardness (Eq. 6).
        Student minimizes KL from teacher (Eq. 7).
        Training mix: 90% adversarial + 10% replay.
        """
        print(f"  [Phase 2] Adversarial Distillation: {self.config.adv_epochs} epochs...")
        n_replay = max(1, int(self.config.batch_size * self.config.replay_ratio))
        n_adv = self.config.batch_size - n_replay

        all_bin_assignments = []  # collect for coverage tracking

        for epoch in range(self.config.adv_epochs):
            self.generator.train()
            self.student.train()

            # ── Generate adversarial samples ──────────────────────────────────
            x_adv = self._generate_batch(n_adv)
            teacher_probs = self._teacher_probs(x_adv)

            # ── Generator Loss (Eq. 6): L_gen = L_diversity + λ_hard · L_hardness
            # Soft bin memberships from the frozen bin learner; gradients still
            # flow back to the generator through x_adv
            m_gen = self.bin_learner(x_adv)
            l_diversity = self.diversity_loss_fn(m_gen)

            # L_hardness (Eq. 5): reward generator for finding student failures.
            # The student output must NOT be detached here — the gradient has to
            # reach the generator through x_adv. Stray student gradients are
            # discarded by student_optimizer.zero_grad() before the student step.
            student_log_probs = self.student(x_adv)
            l_hardness = -F.kl_div(
                student_log_probs, teacher_probs, reduction="batchmean"
            )

            gen_loss = (self.config.lambda_cov * l_diversity
                        + self.config.lambda_hard * l_hardness)

            self.gen_optimizer.zero_grad()
            gen_loss.backward()
            nn.utils.clip_grad_norm_(self.generator.parameters(), 5.0)
            self.gen_optimizer.step()

            # ── Student Loss (Eq. 7): minimize KL(S(x) ‖ T(x)) ───────────────
            x_adv_fresh = self._generate_batch(n_adv).detach()
            teacher_probs_fresh = self._teacher_probs(x_adv_fresh)

            # Mix with replay buffer samples
            replay = self.replay_buffer.sample(n_replay, self.device)
            if replay is not None:
                x_mix = torch.cat([x_adv_fresh, replay], dim=0)
                t_replay = self._teacher_probs(replay)
                t_mix = torch.cat([teacher_probs_fresh, t_replay], dim=0)
            else:
                x_mix, t_mix = x_adv_fresh, teacher_probs_fresh

            student_loss = F.kl_div(self.student(x_mix), t_mix, reduction="batchmean")
            self.student_optimizer.zero_grad()
            student_loss.backward()
            nn.utils.clip_grad_norm_(self.student.parameters(), 5.0)
            self.student_optimizer.step()

            # Track bin assignments for coverage measurement
            with torch.no_grad():
                hard_assignments = self.bin_learner.hard_assignment(x_adv_fresh)
                all_bin_assignments.append(hard_assignments.cpu())

            self.gen_scheduler.step()
            self.student_scheduler.step()

            if (epoch + 1) % 100 == 0:
                all_ba = torch.cat(all_bin_assignments, dim=0)
                coverage = self.diversity_loss_fn.compute_coverage(all_ba)
                print(
                    f"    Adv epoch {epoch+1}/{self.config.adv_epochs} | "
                    f"gen_loss={gen_loss.item():.4f} | "
                    f"student_kl={student_loss.item():.4f} | "
                    f"coverage={coverage:.1%}"
                )
                self.history.append({
                    "epoch": epoch + 1,
                    "gen_loss": gen_loss.item(),
                    "student_kl": student_loss.item(),
                    "coverage": coverage,
                })

    def train(self):
        """Execute full three-phase TabKD training."""
        print(f"\n{'='*60}")
        print(f"  TabKD Training | {self.config.n_features} features | {self.config.n_bins} bins")
        print(f"  Phases: Warmup({self.config.warmup_epochs}) → "
              f"Bins({self.config.bin_epochs}) → Adv({self.config.adv_epochs})")
        print(f"{'='*60}\n")
        self._warmup()
        self._bin_stabilization()
        self._adversarial_distillation()
        print("\n  TabKD training complete.")

    @torch.no_grad()
    def evaluate(self, X_test: Tensor, y_test: Tensor) -> Dict[str, float]:
        """
        Evaluate student model on held-out test data.

        Returns accuracy, F1, teacher-student agreement, and coverage.
        (Eqs. 10–13 in paper)
        """
        self.student.eval()
        student_preds = self.student.predict(X_test)
        teacher_preds = get_teacher_probs(self.teacher, X_test).argmax(dim=-1)
        y_np = y_test.cpu().numpy()
        s_np = student_preds.cpu().numpy()
        t_np = teacher_preds.cpu().numpy()

        acc = (s_np == y_np).mean()

        # F1 score (binary)
        tp = ((s_np == 1) & (y_np == 1)).sum()
        fp = ((s_np == 1) & (y_np == 0)).sum()
        fn = ((s_np == 0) & (y_np == 1)).sum()
        precision = tp / max(1, tp + fp)
        recall = tp / max(1, tp + fn)
        f1 = 2 * precision * recall / max(1e-8, precision + recall)

        # Teacher-student agreement (Eq. 12)
        agreement = (s_np == t_np).mean()

        # Coverage (Eq. 13)
        hard_assignments = self.bin_learner.hard_assignment(X_test)
        coverage = self.diversity_loss_fn.compute_coverage(hard_assignments.cpu())

        return {
            "accuracy": float(acc) * 100,
            "f1": float(f1) * 100,
            "agreement": float(agreement) * 100,
            "coverage": float(coverage) * 100,
        }


# ─── SECTION 9: Dataset Helpers & Evaluation ──────────────────────────────────

def load_benchmark_dataset(
    name: str = "adult",
    n_samples: int = 1000,
    test_ratio: float = 0.2,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Load one of the four TabKD benchmark datasets (Section 5.2).

    Production: download real datasets from UCI Machine Learning Repository:
      - Adult:         https://doi.org/10.24432/C5XW20
      - Breast Cancer: https://doi.org/10.24432/C51P4M
      - Credit:        https://doi.org/10.24432/C55S3H
      - Mushroom:      https://doi.org/10.24432/C5959T

    Production loading example (requires sklearn):
        from sklearn.datasets import load_breast_cancer
        from sklearn.preprocessing import StandardScaler
        data = load_breast_cancer()
        X = StandardScaler().fit_transform(data.data)
        y = data.target

    This function generates synthetic data with the same shape
    as each real dataset for demonstration purposes.

    Returns
    -------
    X_train, X_test, y_train, y_test — all as float32 Tensors
    """
    dataset_specs = {
        "adult":         {"n_features": 14, "name": "Adult Income (48K, 14 features)"},
        "credit":        {"n_features": 23, "name": "Credit Default (30K, 23 features)"},
        "breast_cancer": {"n_features": 30, "name": "Breast Cancer (569, 30 features)"},
        "mushroom":      {"n_features": 22, "name": "Mushroom (8K, 22 features)"},
    }

    if name not in dataset_specs:
        raise ValueError(f"Unknown dataset '{name}'. Choose from: {list(dataset_specs.keys())}")

    spec = dataset_specs[name]
    n_features = spec["n_features"]

    # Try to load real datasets via sklearn if available
    if name == "breast_cancer":
        try:
            from sklearn.datasets import load_breast_cancer
            from sklearn.preprocessing import StandardScaler
            from sklearn.model_selection import train_test_split
            data = load_breast_cancer()
            X = StandardScaler().fit_transform(data.data).astype(np.float32)
            y = data.target.astype(np.int64)
            X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=test_ratio, stratify=y, random_state=42)
            print(f"  Loaded real Breast Cancer dataset: {X.shape[0]} samples, {X.shape[1]} features")
            return (torch.tensor(X_tr), torch.tensor(X_te),
                    torch.tensor(y_tr), torch.tensor(y_te))
        except ImportError:
            pass

    # Synthetic fallback
    torch.manual_seed(42)
    X = torch.randn(n_samples, n_features)
    # Synthetic labels: XOR of first two features (non-linear interaction)
    y = ((X[:, 0] > 0).long() ^ (X[:, 1] > 0).long())

    n_test = int(n_samples * test_ratio)
    n_train = n_samples - n_test
    X_train, X_test = X[:n_train], X[n_train:]
    y_train, y_test = y[:n_train], y[n_train:]

    print(f"  Loaded synthetic '{name}' data: {n_train} train / {n_test} test | {n_features} features")
    return X_train, X_test, y_train, y_test


def run_tabkd_experiment(
    dataset_name: str = "breast_cancer",
    teacher_type: str = "nn",
    n_samples: int = 800,
    warmup_epochs: int = 5,
    bin_epochs: int = 20,
    adv_epochs: int = 40,
    n_bins: int = 8,
    verbose: bool = True,
) -> Dict[str, float]:
    """
    Run a complete TabKD experiment for a dataset + teacher combination.

    Parameters
    ----------
    dataset_name  : "adult" | "credit" | "breast_cancer" | "mushroom"
    teacher_type  : "nn" | "xgboost" | "rf" | "tabtransformer"
    n_samples     : total samples (use full dataset in production)
    warmup_epochs : Stage 0 epochs (paper: 30)
    bin_epochs    : Stage 1 epochs (paper: 200)
    adv_epochs    : Stage 2 epochs (paper: 400)
    n_bins        : bins per feature (paper: 8)

    Returns
    -------
    dict with accuracy, f1, agreement, coverage metrics
    """
    dataset_spec = {
        "adult": 14, "credit": 23, "breast_cancer": 30, "mushroom": 22
    }
    n_features = dataset_spec.get(dataset_name, 14)
    device = torch.device("cpu")

    print(f"\n{'─'*60}")
    print(f"  Dataset: {dataset_name} | Teacher: {teacher_type} | Bins: {n_bins}")
    print(f"{'─'*60}")

    # Load dataset
    X_train, X_test, y_train, y_test = load_benchmark_dataset(dataset_name, n_samples)
    n_features = X_train.shape[1]

    # Build and train teacher
    if teacher_type == "nn":
        teacher = NeuralNetworkTeacher(n_features, n_classes=2).to(device)
        teacher.train_on_data(X_train, y_train, epochs=15)
        teacher.eval()
        teacher_acc = (teacher.predict(X_test) == y_test).float().mean().item() * 100
    elif teacher_type == "xgboost":
        teacher = XGBoostTeacher()
        teacher.fit(X_train.numpy(), y_train.numpy())
        teacher_acc = (teacher.predict(X_test) == y_test).float().mean().item() * 100
    elif teacher_type == "rf":
        teacher = RandomForestTeacher()
        teacher.fit(X_train.numpy(), y_train.numpy())
        teacher_acc = (teacher.predict(X_test) == y_test).float().mean().item() * 100
    elif teacher_type == "tabtransformer":
        teacher = TabTransformerTeacher(n_features, n_classes=2).to(device)
        teacher.train_on_data(X_train, y_train, epochs=15)
        teacher.eval()
        teacher_acc = (teacher.predict(X_test) == y_test).float().mean().item() * 100
    else:
        raise ValueError(f"Unknown teacher type: {teacher_type}")

    print(f"  Teacher accuracy: {teacher_acc:.1f}%")

    # Build TabKD config
    config = TabKDConfig(
        n_features=n_features,
        n_classes=2,
        n_bins=n_bins,
        warmup_epochs=warmup_epochs,
        bin_epochs=bin_epochs,
        adv_epochs=adv_epochs,
    )

    # Run TabKD training
    trainer = TabKDTrainer(config, teacher, device)
    trainer.train()

    # Evaluate
    results = trainer.evaluate(X_test.to(device), y_test.to(device))
    results["teacher_accuracy"] = teacher_acc

    print(f"\n  Results | Acc={results['accuracy']:.1f}% | F1={results['f1']:.1f}% | "
          f"Agreement={results['agreement']:.1f}% | Coverage={results['coverage']:.1f}%")
    print(f"  (Paper target: {teacher_type}+{dataset_name} → see Table 2 for expected values)")
    return results


# ─── SECTION 10: Smoke Test ────────────────────────────────────────────────────

if __name__ == "__main__":
    print("=" * 60)
    print("TabKD — Full Framework Smoke Test")
    print("=" * 60)
    torch.manual_seed(42)
    np.random.seed(42)
    device = torch.device("cpu")

    # ── 1. Soft Bin Learner ──────────────────────────────────────────────────
    print("\n[1/7] Soft Bin Learner — shape and range checks...")
    config = TabKDConfig(n_features=14, n_classes=2, n_bins=8)
    bin_learner = SoftBinLearner(config.n_features, config.n_bins).to(device)
    x_test = torch.randn(32, config.n_features)
    m = bin_learner(x_test)
    assert m.shape == (32, 14, 8), f"Expected (32,14,8), got {m.shape}"
    assert (m >= 0).all() and (m <= 1).all(), "Memberships out of [0,1]"
    assert torch.allclose(m.sum(dim=-1), torch.ones(32, 14), atol=1e-5), "Memberships must sum to 1 per feature"
    print(f"  Membership tensor shape: {tuple(m.shape)} ✓")
    print(f"  Values in [0,1], sum to 1 per feature ✓")

    # ── 2. Bin Learning Loss ─────────────────────────────────────────────────
    print("\n[2/7] Bin Learning Loss...")
    bin_loss_fn = BinLearningLoss(config)
    teacher_probs = torch.softmax(torch.randn(32, 2), dim=-1)
    loss_bin = bin_loss_fn(m, teacher_probs)
    assert torch.isfinite(loss_bin), "Bin loss is not finite"
    print(f"  L_bin = {loss_bin.item():.4f} (finite, differentiable) ✓")

    # ── 3. Interaction Diversity Loss ────────────────────────────────────────
    print("\n[3/7] Interaction Diversity Loss & Coverage...")
    div_loss_fn = InteractionDiversityLoss(config)
    loss_div = div_loss_fn(m)
    assert torch.isfinite(loss_div), "Diversity loss is not finite"
    n_pairs = config.n_features * (config.n_features - 1) // 2
    print(f"  Feature pairs: {n_pairs} | L_diversity = {loss_div.item():.4f} ✓")

    hard_assignments = bin_learner.hard_assignment(x_test)
    coverage = div_loss_fn.compute_coverage(hard_assignments.cpu())
    print(f"  Pairwise coverage on 32 samples: {coverage:.1%}")
    total_cells = n_pairs * (config.n_bins ** 2)
    print(f"  Total pairwise bin cells: {total_cells:,}")

    # ── 4. Generator & Student Forward Pass ──────────────────────────────────
    print("\n[4/7] Generator & Student Networks...")
    gen = TabularGenerator(config).to(device)
    student = TabularStudent(config).to(device)
    x_gen = gen.sample(64, device)
    assert x_gen.shape == (64, 14), f"Generator output shape mismatch: {x_gen.shape}"
    log_probs = student(x_gen)
    assert log_probs.shape == (64, 2), f"Student output shape mismatch: {log_probs.shape}"
    assert torch.allclose(log_probs.exp().sum(dim=-1), torch.ones(64), atol=1e-5)
    print(f"  Generator: noise({config.noise_dim}) → samples({config.n_features}) ✓")
    print(f"  Student: samples({config.n_features}) → log_probs({config.n_classes}) ✓")
    n_gen = sum(p.numel() for p in gen.parameters())
    n_stu = sum(p.numel() for p in student.parameters())
    print(f"  Generator params: {n_gen:,} | Student params: {n_stu:,}")

    # ── 5. Neural Network Teacher ─────────────────────────────────────────────
    print("\n[5/7] Teacher Training (Neural Network)...")
    teacher_nn = NeuralNetworkTeacher(config.n_features, config.n_classes).to(device)
    X_mock = torch.randn(200, config.n_features)
    y_mock = ((X_mock[:, 0] > 0).long() ^ (X_mock[:, 1] > 0).long())
    loss_final = teacher_nn.train_on_data(X_mock, y_mock, epochs=10)
    teacher_nn.eval()
    teacher_probs_mock = get_teacher_probs(teacher_nn, X_mock[:8])
    assert teacher_probs_mock.shape == (8, 2)
    assert torch.allclose(teacher_probs_mock.sum(dim=-1), torch.ones(8), atol=1e-5)
    print(f"  Teacher trained | final loss: {loss_final:.4f}")
    print(f"  Teacher probs shape: {tuple(teacher_probs_mock.shape)} | sum to 1 ✓")

    # ── 6. Replay Buffer ─────────────────────────────────────────────────────
    print("\n[6/7] Replay Buffer...")
    buf = ReplayBuffer(max_size=500)
    for _ in range(10):
        buf.add(torch.randn(20, config.n_features))
    sample = buf.sample(32, device)
    assert sample is not None and sample.shape[1] == config.n_features
    print(f"  Buffer size: {len(buf)} | Sample shape: {tuple(sample.shape)} ✓")

    # ── 7. Full Training Run (short) ──────────────────────────────────────────
    print("\n[7/7] Full TabKD Training Run (Breast Cancer, NN teacher, mini-epochs)...")
    results = run_tabkd_experiment(
        dataset_name="breast_cancer",
        teacher_type="nn",
        n_samples=400,
        warmup_epochs=3,
        bin_epochs=10,
        adv_epochs=20,
        n_bins=8,
    )

    print("\n" + "=" * 60)
    print("✓  All TabKD checks passed. Framework is ready for use.")
    print("=" * 60)
    print("""
Next steps to reproduce paper results:
  1. Download real datasets from UCI Repository:
       Adult:         https://doi.org/10.24432/C5XW20
       Breast Cancer: https://doi.org/10.24432/C51P4M
       Credit:        https://doi.org/10.24432/C55S3H
       Mushroom:      https://doi.org/10.24432/C5959T

  2. Install full dependencies:
       pip install scikit-learn xgboost tab-transformer-pytorch

  3. Run full training with paper hyperparameters:
       config = TabKDConfig(
           n_bins=8, batch_size=128, lr=0.001,
           warmup_epochs=30, bin_epochs=200, adv_epochs=400,
           lambda_cov=10.0, lambda_hard=2.0,
       )

  4. Apply teacher-specific temperature schedules (Table 1):
       NN:              tau_start=1.0, tau_end=0.05
       Random Forest:   tau_start=1.2, tau_end=0.08
       XGBoost:         tau_start=1.5, tau_end=0.10
       TabTransformer:  tau_start=1.2, tau_end=0.08

  5. Access open-source implementation:
       https://anonymous.4open.science/r/int_div-0413/README.md

  6. Expected results (agreement %):
       Breast Cancer + NN:       ~95.6%
       Breast Cancer + XGBoost:  ~96.5%
       Credit + NN:              ~97.0%
       Mushroom + NN:            ~96.5%
""")
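One PyTorch subtlety worth flagging for readers tracing the distillation losses in the listing: `F.kl_div(input, target)` expects log-probabilities as `input` and computes KL(target ‖ input), so `F.kl_div(log_S, T)` is KL(T ‖ S), the standard teacher-to-student distillation direction. A quick self-contained check (toy distributions, not values from the paper):

```python
import torch
import torch.nn.functional as F

# Toy 2-class distributions: a confident teacher vs. a uniform student
T = torch.tensor([[0.9, 0.1]])                 # teacher probabilities
log_S = torch.log(torch.tensor([[0.5, 0.5]]))  # student log-probabilities

# F.kl_div(input=log_S, target=T) = Σ T · (log T − log S) = KL(T ‖ S)
kl = F.kl_div(log_S, T, reduction="batchmean")

# Verify against the definition computed by hand
manual = (T * (T.log() - log_S)).sum()
print(torch.allclose(kl, manual))  # True
```

Getting the argument order backwards silently optimizes KL(S ‖ T) instead, which is mode-seeking rather than mode-covering and changes what the student learns.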

Read the Full Paper & Access the Code

The complete study — including all 16 dataset-teacher result tables, coverage-vs-agreement correlation plots across all teacher types, ablation tables, and the open-source implementation — is available on arXiv and the anonymous code repository.

Academic Citation:
Pereira, S. N., Khadka, K., & Lei, Y. (2026). TabKD: Tabular Knowledge Distillation through Interaction Diversity of Learned Feature Bins. arXiv:2603.15481v1 [cs.LG]. University of Texas at Arlington.

This article is an independent editorial analysis of peer-reviewed research. The Python implementation is an educational adaptation covering all described algorithmic components. For exact replication, use the official repository with real benchmark datasets downloaded from the UCI Machine Learning Repository.
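As a standalone illustration of the pairwise interaction-coverage idea at the heart of TabKD (the quantity `compute_coverage` tracks in the listing above, Eq. 13), here is a minimal sketch — independent of the framework, with a hypothetical `pairwise_coverage` helper — that counts, over every feature pair, the fraction of (bin, bin) cells hit by a batch of hard bin assignments:

```python
import torch

def pairwise_coverage(assignments: torch.Tensor, n_bins: int) -> float:
    """Fraction of (bin_i, bin_j) cells covered, averaged over all feature pairs.

    assignments : (N, F) integer tensor of per-feature bin indices in [0, n_bins).
    """
    N, n_feat = assignments.shape
    covered, total = 0, 0
    for i in range(n_feat):
        for j in range(i + 1, n_feat):
            # Encode each (bin_i, bin_j) pair as a single cell id, count uniques
            cells = assignments[:, i] * n_bins + assignments[:, j]
            covered += cells.unique().numel()
            total += n_bins * n_bins
    return covered / total

# 3 features, 2 bins → 3 pairs × 4 cells = 12 cells; two samples hit 6 of them
a = torch.tensor([[0, 0, 0], [1, 1, 1]])
print(f"{pairwise_coverage(a, n_bins=2):.3f}")  # → 0.500
```

This is the combinatorial-testing analogy in miniature: with 8 bins and F features, there are C(F, 2) · 64 cells, and the generator is rewarded for filling them rather than sampling blindly.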
