TabKD: What Happens When You Teach a Tiny Model to Think Like XGBoost — Without Seeing Any Real Data
Researchers at UT Arlington built a data-free knowledge distillation framework for tabular models that borrows a principle from software testing — systematically covering all pairs of feature interactions — and achieves the best student-teacher agreement in 14 of 16 benchmark configurations across neural networks, XGBoost, Random Forest, and TabTransformer teachers, outperforming five state-of-the-art baselines.
A hospital trains a cancer-risk model on tens of thousands of patient records. That model works brilliantly, but patient privacy laws mean the training data can never leave the hospital’s servers. The model itself is too large for the mobile diagnostic device the hospital wants to deploy. How do you compress the model when you cannot access the data it was trained on? This is the data-free knowledge distillation problem — and for tabular data, it has been quietly unsolved. TabKD solves it by asking a question borrowed from software engineering: instead of generating random synthetic samples and hoping they cover the important cases, what if we could guarantee that every meaningful combination of feature behaviors gets tested?
Why Tabular Data Is Genuinely Different
When researchers talk about knowledge distillation, they almost always mean image models. And the methods that work for images — adversarial generators, batch normalization inversion, spatial activation matching — rely on assumptions that simply do not hold for tabular data.
The first assumption is that the teacher is differentiable. You can backpropagate through a neural network to generate inputs that fool it. You cannot backpropagate through XGBoost or Random Forest. These are the most widely used models in tabular domains precisely because their ensemble structure makes them powerful — but it also makes them opaque to gradient-based generation.
The second assumption is that important features interact locally and hierarchically. Image models exploit spatial locality: a pixel matters primarily in relation to its neighbors. Tabular models work completely differently. A credit risk model doesn’t look at age and debt-to-income ratio as separate signals — it learns that the combination (age > 50 AND debt-to-income < 0.3) predicts low risk, while neither condition alone is sufficient. These are sharp, non-linear interactions across features that may have nothing structurally in common. There is no spatial neighborhood. There is no locality prior.
The third assumption is that random or entropy-guided generation will eventually explore the parts of feature space that matter. For tabular data with many features, this is overoptimistic. The combinatorial explosion of possible feature combinations means that random sampling reliably misses specific pairwise interactions, especially rare ones, leaving the student model permanently blind to the decision rules that govern them.
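A back-of-the-envelope calculation makes the point concrete. The sketch below assumes independent features and i.i.d. sampling; the 1% bin frequency is an illustrative number rather than a figure from the paper, and `prob_cell_missed` is a hypothetical helper.

```python
def prob_cell_missed(cell_prob: float, n_samples: int) -> float:
    """Probability that n i.i.d. samples never land in a cell hit with probability cell_prob."""
    return (1.0 - cell_prob) ** n_samples


# Assumption: two independent features, each with a rare bin holding 1% of samples.
rare_bin_prob = 0.01
cell_prob = rare_bin_prob * rare_bin_prob  # joint probability of the rare-rare cell

for n in (1_000, 9_600, 50_000):
    print(f"n={n:>6}: P(miss) = {prob_cell_missed(cell_prob, n):.3f}")
```

Under these assumptions, even 9,600 random samples miss that single rare-rare cell roughly 38% of the time, and a real table has thousands of such cells.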
Existing data-free distillation methods exhibit mode collapse on tabular data — the generator fixates on a small region of the feature space and never explores critical decision rules encoded in specific feature combinations. The student ends up mimicking the teacher’s average behavior but missing the precise interaction patterns that make the teacher accurate on hard cases.
The Insight From Software Testing
The breakthrough in TabKD comes from an unexpected direction: combinatorial testing, a methodology from software engineering developed to find bugs in complex systems efficiently.
The key observation behind combinatorial testing is this: software bugs are almost never triggered by all parameters being in unusual states simultaneously. In practice, nearly all known bugs are triggered by interactions among just two or three parameters. This means you do not need to test every possible combination of all parameters — you just need to guarantee that every pair of parameter values (2-way coverage) gets exercised at least once. This requires dramatically fewer test cases than exhaustive testing, while still catching the vast majority of real bugs.
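The saving is easy to see on a toy system. The sketch below uses three boolean parameters (an illustrative setup, not taken from the paper) and checks that a four-case test set already exercises every pair of values, where exhaustive testing needs eight cases. `covers_all_pairs` is a hypothetical helper.

```python
from itertools import combinations, product


def covers_all_pairs(tests, n_values):
    """Return True if every pair of parameters sees every combination of values."""
    n_params = len(tests[0])
    for i, j in combinations(range(n_params), 2):
        seen = {(t[i], t[j]) for t in tests}
        if len(seen) < n_values * n_values:
            return False
    return True


# Three boolean parameters: exhaustive testing needs 2**3 = 8 cases.
exhaustive = list(product([0, 1], repeat=3))

# A 4-case set that still exercises every value pair for every parameter pair.
pairwise = [(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0)]

print(len(exhaustive), covers_all_pairs(exhaustive, 2))  # 8 True
print(len(pairwise), covers_all_pairs(pairwise, 2))      # 4 True
```

The gap widens quickly as the number of parameters grows, which is what makes 2-way coverage practical.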
TabKD applies exactly this logic to synthetic data generation for tabular model distillation. The authors argue, and demonstrate empirically, that tabular model decisions similarly depend on interactions among small feature subsets. The credit risk rule (age > 50 AND debt-to-income < 0.3 → low risk) is a pairwise interaction. So is the medical rule (blood pressure > threshold AND cholesterol > threshold → at-risk). To the extent that the teacher's rules are pairwise, covering all pairwise feature combinations in the synthetic training data is not just a nice-to-have: it guarantees that the student sees an example of every such decision pattern.
But to cover pairwise feature interactions, you first need to discretize each feature into a finite set of meaningful regions — bins. Random uniform bins waste coverage: if a feature takes values from 1 to 100 and the teacher’s decision boundary lies at 47, then a uniform bin covering [40, 60] lumps together values with completely different predictions. You need bins aligned with where the teacher’s behavior actually changes.
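The example in the paragraph above can be checked numerically. This is a minimal sketch with a toy step-function teacher (an assumption for illustration): it compares the variance of teacher predictions inside the uniform bin [40, 60) with the variance inside two bins split at the boundary.

```python
from statistics import pvariance


def teacher(x: int) -> int:
    """Toy teacher (assumption): predicts 1 when the feature is at least 47."""
    return 1 if x >= 47 else 0


def bin_variance(lo: int, hi: int) -> float:
    """Population variance of teacher predictions for integer values in [lo, hi)."""
    return float(pvariance([teacher(x) for x in range(lo, hi)]))


uniform_bin = bin_variance(40, 60)                            # straddles the boundary at 47
aligned_bins = (bin_variance(40, 47), bin_variance(47, 60))   # split at the boundary

print("uniform bin variance:", uniform_bin)    # positive: mixed predictions in one bin
print("aligned bin variances:", aligned_bins)  # both zero: each bin is homogeneous
```

Zero within-bin variance is what the bin learning objective described below pushes toward.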
The TabKD Framework: Three Stages
TABKD FRAMEWORK — THREE-STAGE TRAINING
════════════════════════════════════════════════════════════════
STAGE 0 — WARMUP (30 epochs)
Random noise z ~ N(0, I)
│
Generator G(z) → synthetic samples x_gen
│
Teacher T(x_gen) → soft labels (ground truth)
│
Student S(x_gen) → minimize KL divergence
│
Populate replay buffer (later supplies 10% of each student batch as stability anchors)
│
Goal: Give student a reasonable starting point before
adversarial training begins.
────────────────────────────────────────────────────────────────
STAGE 1 — BIN LEARNING (200 epochs, linear temp annealing)
Soft Bin Membership m_k^(i)(x):
Each feature i is partitioned into K=8 adaptive bins.
Membership is SOFT (differentiable, not hard assignment).
│
Bin Loss (Eq. 1):
L_bin = λ_intra · Var_intra(M, P_T) ← minimize: same predictions inside bins
+ λ_inter · 1/Var_inter(M, P_T) ← maximize: different predictions across bins
│
Boundary-Focused Generator Loss (Eq. 2):
L_gen^(1) = λ_div · L_class-div(S(x_gen)) ← class balance
+ λ_boundary · L_entropy(T(x_gen)) ← target high-uncertainty regions
│
Goal: Bins stabilize at teacher decision boundaries.
Bins are FROZEN before Stage 2 begins.
⚠ Staged approach critical: jointly optimizing bins and
coverage causes instability (generator fills bins before
boundaries stabilize).
────────────────────────────────────────────────────────────────
STAGE 2 — ADVERSARIAL DISTILLATION (400 epochs)
With frozen bins, the interaction space is now finite:
F*(F-1)/2 feature pairs × K² bin combinations per pair.
Empirical joint bin distribution (Eq. 3):
P(k1, k2 | i, j) = (1/N) Σ_x m_k1^(i)(x) · m_k2^(j)(x)
Diversity Loss — maximize entropy = uniform coverage (Eq. 4):
L_diversity = −(1/C(F,2)) Σ_{i<j} H(P(·, · | i, j))
Stage 1: Dynamic Bin Learning
The bin learner assigns each input sample a soft membership vector over K bins for each feature. Soft membership is critical — hard discrete assignment would break differentiability and make the bin boundaries untrainable. The training objective pulls bin boundaries toward the teacher's decision boundaries by minimizing prediction variance within bins (samples in the same bin should get similar teacher predictions) and maximizing prediction variance between bins (samples in different bins should get different teacher predictions).
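A minimal sketch of soft membership and temperature annealing follows. It assumes the same form as the listing later in this article (a softmax over negative squared distances to bin centers) and the listing's default temperatures of 1.0 and 0.05; the three bin centers and the input value are illustrative.

```python
import math


def soft_membership(x, centers, tau):
    """Softmax over negative squared distances to bin centers, scaled by temperature tau."""
    logits = [-((x - c) ** 2) / tau for c in centers]
    peak = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(v - peak) for v in logits]
    total = sum(weights)
    return [w / total for w in weights]


def linear_tau(epoch, n_epochs, tau_start=1.0, tau_end=0.05):
    """Linearly anneal the temperature from tau_start (first epoch) to tau_end (last epoch)."""
    frac = epoch / max(1, n_epochs - 1)
    return tau_start + frac * (tau_end - tau_start)


centers = [-1.0, 0.0, 1.0]  # assumption: three illustrative bin centers
x = 0.4
for epoch in (0, 100, 199):
    tau = linear_tau(epoch, 200)
    m = soft_membership(x, centers, tau)
    print(f"epoch={epoch:3d} tau={tau:.3f} m={[round(v, 3) for v in m]}")
```

At high temperature the membership is spread over neighboring bins, which keeps gradients flowing to every center; as the temperature falls, the assignment approaches a hard one.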
During bin learning, a separate boundary-focused generator helps by producing samples near the teacher's decision boundaries — exactly where bin assignments are most informative. Entropy maximization over teacher predictions targets high-uncertainty regions where predictions are approximately 0.5, which correspond to the boundaries the bins need to capture.
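For a binary teacher, the link between entropy and the decision boundary can be seen directly. The sketch below computes the Shannon entropy of a single predicted probability; `binary_entropy` is a hypothetical helper, not the paper's loss implementation.

```python
import math


def binary_entropy(p: float) -> float:
    """Shannon entropy (in nats) of a Bernoulli(p) prediction."""
    if p in (0.0, 1.0):
        return 0.0
    return -(p * math.log(p) + (1 - p) * math.log(1 - p))


for p in (0.05, 0.3, 0.5, 0.7, 0.95):
    print(f"p={p:.2f}  H={binary_entropy(p):.3f}")
```

Entropy peaks at p = 0.5 with value log 2, so a generator rewarded for high teacher entropy is steered toward inputs on which the teacher is undecided.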
Stage 2: Interaction Diversity Loss
With bins frozen, the interaction space becomes finite and enumerable. For F features and K bins each, there are F*(F-1)/2 unique feature pairs, and each pair has K² possible bin combinations. Full 2-way coverage means the generator must produce at least one sample in each of these K² cells for every feature pair.
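To get a feel for the size of this space, the sketch below counts pairs and cells for the four benchmark datasets, using the feature counts reported in the experimental setup and K = 8. It assumes each raw feature is one input column (no one-hot expansion); `pairwise_cells` is a hypothetical helper.

```python
from math import comb


def pairwise_cells(n_features: int, n_bins: int) -> tuple[int, int]:
    """Return (number of feature pairs, number of pairwise bin cells)."""
    n_pairs = comb(n_features, 2)          # F*(F-1)/2 unique feature pairs
    return n_pairs, n_pairs * n_bins ** 2  # K^2 bin combinations per pair


K = 8
datasets = {"Adult": 14, "Credit Card Default": 23, "Breast Cancer": 30, "Mushroom": 22}

for name, n_features in datasets.items():
    n_pairs, n_cells = pairwise_cells(n_features, K)
    print(f"{name:20s} pairs={n_pairs:4d} cells={n_cells:6d}")

# Exhaustive coverage of all 14 Adult features at once would need K**14 cells.
print(f"exhaustive (Adult): {K ** 14:,}")
```

Under hard assignment, every sample lands in exactly one cell of every feature pair, so full coverage needs at least K² = 64 samples; the pairwise cell count stays in the thousands while the exhaustive count for Adult is about 4.4 trillion.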
The coverage metric is the empirical joint distribution over bin combinations for each feature pair. Maximum entropy over this distribution corresponds to uniform coverage. TabKD's diversity loss minimizes negative entropy (equivalently, maximizes entropy) averaged over all pairwise joint distributions:
L_diversity = −(1/C(F,2)) Σ_{i<j} H(P(·, · | i, j)), where H(P) = −Σ_{k1,k2} P(k1, k2 | i, j) · log P(k1, k2 | i, j)
This is elegant in its simplicity. The generator does not need to enumerate all possible bin combinations explicitly — it just needs to produce samples whose bin assignments form a uniform distribution across all feature pairs. The entropy loss naturally pushes the generator toward underexplored combinations.
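A small worked example for a single feature pair shows how the entropy separates a collapsed generator from a diverse one. This is a pure-Python sketch with K = 2 bins and four hand-written samples (both illustrative assumptions); `joint_entropy` is a hypothetical helper.

```python
import math


def joint_entropy(m_i, m_j):
    """Entropy (nats) of the empirical joint bin distribution for one feature pair.

    m_i, m_j: lists of per-sample soft membership vectors (each sums to 1).
    """
    n, k = len(m_i), len(m_i[0])
    # Average the per-sample outer products of the two membership vectors.
    joint = [[sum(m_i[s][a] * m_j[s][b] for s in range(n)) / n for b in range(k)]
             for a in range(k)]
    entropy = 0.0
    for row in joint:
        for p in row:
            if p > 0:
                entropy -= p * math.log(p)
    return entropy


# Collapsed generator: every sample sits in bin 0 of both features.
collapsed = [[1.0, 0.0]] * 4
# Diverse generator: the four samples hit the four (bin, bin) cells once each.
f1 = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
f2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]

print(joint_entropy(collapsed, collapsed))  # 0.0: a single cell is ever visited
print(joint_entropy(f1, f2))                # ≈ 1.386 = log(4): uniform over 4 cells
```

The diversity loss is the negative of this quantity averaged over all feature pairs, so minimizing it moves the generator from the first situation toward the second.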
Hardness + Stability
Diversity alone is not enough. A generator that covers all bin combinations uniformly but avoids regions where the student currently disagrees with the teacher is wasteful — it generates samples the student already handles correctly. The hardness loss adds a complementary objective: reward the generator for finding samples where the student fails. KL divergence between teacher and student outputs (negated, so the generator maximizes it) naturally concentrates samples near decision boundaries where student errors are most costly.
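The sketch below illustrates why KL divergence works as a hardness signal. It assumes the KL(teacher ‖ student) direction and uses made-up probabilities; `kl_divergence` is a hypothetical helper, not the paper's code.

```python
import math


def kl_divergence(teacher_probs, student_probs, eps=1e-12):
    """KL(teacher || student) in nats for one sample's class probabilities."""
    return sum(t * math.log((t + eps) / (s + eps))
               for t, s in zip(teacher_probs, student_probs) if t > 0)


teacher = [0.9, 0.1]
agreeing_student = [0.88, 0.12]    # student already close to the teacher
disagreeing_student = [0.3, 0.7]   # student predicts the opposite class

easy = kl_divergence(teacher, agreeing_student)
hard = kl_divergence(teacher, disagreeing_student)
print(f"easy sample KL={easy:.4f}, hard sample KL={hard:.4f}")

# The generator's hardness term is the negated KL, so minimizing it favors hard samples.
hardness_loss = -hard
```

A sample the student already handles contributes almost nothing, while a sample on which the student flips the teacher's prediction yields a large KL, so gradient descent on the negated term steers the generator toward the latter.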
The student, meanwhile, learns from a mix of 90% adversarial samples (where it is most likely to be wrong) and 10% replay buffer samples collected during warmup. The replay buffer prevents catastrophic forgetting — without it, the student would optimize perfectly for edge cases while degrading on the ordinary inputs it learned during warmup.
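One way to assemble such a batch is sketched below. The batch size of 128 and the 0.1 replay ratio follow the configuration in the listing later in this article; the rounding rule and the `mixed_batch` helper are assumptions for illustration.

```python
import random


def mixed_batch(adversarial, replay, batch_size=128, replay_ratio=0.1, seed=None):
    """Build a student batch: about replay_ratio from the replay buffer, the rest adversarial."""
    rng = random.Random(seed)
    n_replay = min(len(replay), round(batch_size * replay_ratio))
    n_adv = batch_size - n_replay
    batch = rng.sample(replay, n_replay) + rng.sample(adversarial, n_adv)
    rng.shuffle(batch)
    return batch, n_adv, n_replay


# Placeholder pools; in practice these would be tensors of synthetic samples.
adversarial = [("adv", i) for i in range(500)]
replay = [("replay", i) for i in range(2000)]

batch, n_adv, n_replay = mixed_batch(adversarial, replay, seed=0)
print(len(batch), n_adv, n_replay)  # 128 115 13
```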
The three-stage training is not just an engineering convenience — it prevents a specific failure mode. If you jointly optimize bins and coverage from the start, the generator fills bins faster than bin boundaries can stabilize. By the time the generator is producing diverse samples, the bins are misaligned and "diverse" means nothing. Freezing bins before adversarial training begins ensures that coverage statistics reflect genuine semantic diversity in the teacher's prediction landscape.
Experimental Setup and Benchmarks
TabKD is evaluated across four binary classification datasets: Adult Income (48K samples, 14 features — predicting whether income exceeds $50K), Credit Card Default (30K samples, 23 features — predicting payment default), Breast Cancer (569 samples, 30 features — malignant vs. benign), and Mushroom (8K samples, 22 features — edible vs. poisonous).
Four teacher architectures are evaluated: a neural network (two hidden layers, 128→64 units), XGBoost (100 estimators, depth 6), Random Forest (100 trees, depth 10), and TabTransformer (multi-head self-attention over tabular features). This gives 16 dataset-teacher combinations. Every distillation method is allowed a query budget of 9,600 synthetic samples per teacher, following TabExtractor's evaluation protocol.
The student is always a single-hidden-layer network with 32 units — a deliberately lightweight model that makes the distillation challenge harder. The same student architecture is used regardless of the teacher's complexity, making the results clean and fair.
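A quick parameter count shows how small the student is relative to the neural network teacher. The sketch assumes the 14 Adult features enter the network as 14 inputs (no one-hot expansion) and two output classes; `mlp_param_count` is a hypothetical helper.

```python
def mlp_param_count(layer_sizes):
    """Weights plus biases for a fully connected network with the given layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


n_features, n_classes = 14, 2  # Adult Income, binary task
student = mlp_param_count([n_features, 32, n_classes])
teacher = mlp_param_count([n_features, 128, 64, n_classes])
print(student, teacher)  # 546 10306
```

Under these assumptions the student has roughly 19 times fewer parameters than the neural network teacher.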
Results: 14 of 16 Wins
Headline Numbers: Breast Cancer and Mushroom
| Method | Breast Cancer NN Agree | Breast Cancer XGB Agree | Mushroom NN Agree | Mushroom RF Agree |
|---|---|---|---|---|
| StealML | 75.4 | 88.6 | 94.0 | 85.5 |
| TabExtractor | 85.1 | 86.1 | 90.5 | 79.0 |
| CF | 86.0 | 79.2 | 63.0 | 65.5 |
| DualCF | 83.3 | 72.8 | 90.0 | 88.5 |
| DivT | 87.7 | 89.5 | 93.0 | 88.0 |
| TabKD (Ours) | 95.6 | 96.5 | 96.5 | 91.1 |
Selected agreement accuracy (%) results. TabKD's margins on Breast Cancer are particularly large: with the Neural Network teacher it leads DivT, the strongest baseline, by 7.9 percentage points and StealML by about 20.
The Full Picture: All Datasets and Teachers
| Dataset | Teacher | Best Baseline Agree | TabKD Agree | Δ |
|---|---|---|---|---|
| Adult | NN | 91.4 (DivT) | 91.0 | −0.4 (DivT wins) |
| Adult | XGBoost | 78.6 | 81.1 | +2.5 |
| Adult | RF | 69.1 | 84.4 | +15.3 |
| Adult | TabTransformer | 85.0 | 86.1 | +1.1 |
| Credit | NN | 90.9 | 97.0 | +6.1 |
| Credit | XGBoost | 85.8 | 88.0 | +2.2 |
| Credit | RF | 87.5 | 87.7 | +0.2 |
| Credit | TabTransformer | 81.4 | 87.1 | +5.7 |
| Breast Cancer | NN | 87.7 | 95.6 | +7.9 |
| Breast Cancer | XGBoost | 89.5 | 96.5 | +7.0 |
| Breast Cancer | RF | 83.5 | 91.2 | +7.7 |
| Breast Cancer | TabTransformer | 89.1 | 90.4 | +1.3 |
| Mushroom | NN | 94.0 | 96.5 | +2.5 |
| Mushroom | XGBoost | 81.5 | 83.5 | +2.0 |
| Mushroom | RF | 88.5 | 91.1 | +2.6 |
| Mushroom | TabTransformer | 88.3 (DivT) | 85.2 | −3.1 (DivT wins) |
Table: TabKD vs. best baseline agreement accuracy (%). TabKD leads in 14/16 configurations. The two exceptions, both led by DivT, are Adult+NN (a marginal −0.4 points) and Mushroom+TabTransformer (−3.1 points).
The two cases where TabKD does not lead are instructive. Adult with Neural Network teacher: DivT achieves 91.4% vs. TabKD's 91.0%, a gap of 0.4 percentage points that is small enough to plausibly fall within run-to-run variation. Mushroom with TabTransformer: DivT outperforms at 88.3% vs. 85.2%. The authors note that Mushroom's features are largely categorical with clear rules, where even uniform static bins work well. The adaptive bin learning provides less marginal benefit when the feature space is already naturally discrete.
"Interaction coverage strongly correlates with distillation quality, validating our core hypothesis that systematic coverage of feature combinations is essential for effective tabular distillation." — Pereira, Khadka, and Lei, arXiv:2603.15481 (2026)
The Coverage-Agreement Correlation
The most important result in the paper is not a single accuracy number — it is the relationship between interaction coverage and student-teacher agreement, shown across all four teacher types. As coverage increases from ~20% to ~80% of all pairwise bin combinations, agreement accuracy rises monotonically for every teacher architecture. The relationship is consistent and strong: more coverage means better distillation, regardless of what kind of model you are trying to distill.
This validates TabKD's core premise. It is not merely a clever trick for getting slightly better numbers — it is a demonstration that the reason prior methods fail is their inability to cover the feature interaction space, and that fixing this problem directly improves distillation quality in a measurable, predictable way.
Ablation: Why Dynamic Bins Matter
| Teacher | Adult (Dynamic) | Adult (Static) | Credit (Dynamic) | Credit (Static) |
|---|---|---|---|---|
| Neural Network | 92.0 | 87.0 | 97.0 | 91.0 |
| XGBoost | 81.1 | 42.0 | 88.5 | 88.0 |
| Random Forest | 84.4 | 58.9 | 88.2 | 88.2 |
| TabTransformer | 86.1 | 33.0 | 87.1 | 86.6 |
Table 3: Dynamic vs. static binning ablation. The Adult dataset shows catastrophic degradation with static bins — XGBoost drops from 81.1% to 42.0%, TabTransformer from 86.1% to 33.0%.
The ablation study on static bins delivers the clearest demonstration of why adaptive bin learning matters. On the Adult dataset — which is complex, imbalanced, and has continuous features where decision boundaries are not obvious — replacing dynamic bins with uniform equal-width bins is catastrophic. XGBoost agreement drops from 81.1% to 42.0%. TabTransformer drops from 86.1% to 33.0%. Without bins aligned to teacher decision boundaries, the coverage metric is measuring the wrong thing: "uniform coverage" across meaningless uniform bins guarantees nothing about meaningful feature interaction coverage.
On simpler datasets like Mushroom, where many features are categorical and their natural values already act as meaningful bins, dynamic and static bins perform similarly. This makes intuitive sense and gives confidence that the dynamic bin learning is doing real work rather than overfitting to a quirk of the evaluation.
For practitioners considering TabKD: the dynamic bin learning stage is not optional on continuous-feature datasets, where static bins can cause catastrophic failure. The 200-epoch bin stabilization phase is the most critical investment in the training pipeline. Once bins are well aligned, the adversarial distillation phase produced strong results across all four teacher types in these experiments, including non-differentiable tree ensembles such as XGBoost and Random Forest.
Complete End-to-End TabKD Implementation (PyTorch)
The implementation below is a complete, runnable PyTorch implementation of TabKD, structured across 10 sections that map directly to the paper. It covers dynamic soft bin learning with boundary-focused sampling, the interaction diversity loss with entropy maximization over pairwise joint bin distributions, the hardness loss via KL divergence, the three-phase training loop (warmup → bin stabilization → adversarial distillation), replay buffer management, all four teacher wrappers (Neural Network, XGBoost, Random Forest, TabTransformer), dataset helpers for all four benchmark datasets, evaluation metrics including coverage quantification, and a smoke test that validates all components without requiring real data.
# ==============================================================================
# TabKD: Tabular Knowledge Distillation through Interaction Diversity
# of Learned Feature Bins
# Paper: arXiv:2603.15481v1 [cs.LG] (2026)
# Authors: Shovon Niverd Pereira, Krishna Khadka, Yu Lei
# Affiliation: University of Texas at Arlington
# ==============================================================================
# Sections:
# 1. Imports & Configuration
# 2. Dynamic Soft Bin Learner (decision-boundary-aligned discretization)
# 3. Interaction Diversity Loss (pairwise bin entropy maximization)
# 4. Generator Network (noise → synthetic tabular samples)
# 5. Student Network (lightweight MLP approximation)
# 6. Teacher Wrappers (NN, XGBoost, Random Forest, TabTransformer)
# 7. Replay Buffer (catastrophic forgetting prevention)
# 8. TabKD Trainer (three-phase: warmup → bins → adversarial)
# 9. Dataset Helpers & Evaluation Metrics
# 10. Smoke Test
# ==============================================================================
from __future__ import annotations
import math
import random
import warnings
from collections import deque
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset
warnings.filterwarnings("ignore")
# ─── SECTION 1: Configuration ─────────────────────────────────────────────────
@dataclass
class TabKDConfig:
    """
    Hyperparameter configuration for TabKD (Section 5.7).

    Attributes
    ----------
    n_features : int - number of input features (F)
    n_classes : int - number of output classes (C)
    n_bins : int - bins per feature K (paper: 8)
    noise_dim : int - generator input noise dimension
    gen_hidden : int - generator hidden layer size
    student_hidden : int - student MLP hidden units (paper: 32)
    batch_size : int - training batch size (paper: 128)
    lr : float - Adam learning rate (paper: 0.001)
    lambda_cov : float - diversity loss weight (paper: 10.0)
    lambda_hard : float - hardness loss weight (paper: 2.0)
    lambda_intra : float - intra-bin variance weight
    lambda_inter : float - inter-bin inverse-variance weight
    lambda_div : float - class diversity weight (bin phase generator)
    lambda_boundary : float - boundary entropy weight (bin phase generator)
    warmup_epochs : int - Stage 0 warmup epochs (paper: 30)
    bin_epochs : int - Stage 1 bin learning epochs (paper: 200)
    adv_epochs : int - Stage 2 adversarial epochs (paper: 400)
    tau_start : float - temperature annealing start
    tau_end : float - temperature annealing end
    replay_ratio : float - fraction of replay samples in adversarial batch (paper: 0.1)
    replay_size : int - maximum replay buffer size
    query_budget : int - total synthetic query budget (paper: 9600)
    """
    n_features: int = 14
    n_classes: int = 2
    n_bins: int = 8
    noise_dim: int = 64
    gen_hidden: int = 128
    student_hidden: int = 32
    batch_size: int = 128
    lr: float = 1e-3
    lambda_cov: float = 10.0
    lambda_hard: float = 2.0
    lambda_intra: float = 1.0
    lambda_inter: float = 1.0
    lambda_div: float = 1.0
    lambda_boundary: float = 1.0
    warmup_epochs: int = 30
    bin_epochs: int = 200
    adv_epochs: int = 400
    tau_start: float = 1.0
    tau_end: float = 0.05
    replay_ratio: float = 0.1
    replay_size: int = 2000
    query_budget: int = 9600
# ─── SECTION 2: Dynamic Soft Bin Learner ──────────────────────────────────────
class SoftBinLearner(nn.Module):
    """
    Learns K adaptive bin boundaries per feature, aligned with teacher
    decision boundaries (Section 4.2).

    Instead of hard discrete assignments (which are non-differentiable),
    uses a temperature-annealed softmax over learned bin center offsets
    to produce soft membership vectors m ∈ [0,1]^K for each feature.

    Bin Learning Objective (Eq. 1):
        L_bin = λ_intra · Var_intra(M, P_T)
              + λ_inter · (1 / Var_inter(M, P_T))

    Intra-bin variance: for each bin, the variance of teacher predictions
        weighted by soft memberships. Low = homogeneous bins (desired).
    Inter-bin variance: variance across bin-mean predictions.
        High = bins capture different behaviors (desired).

    Parameters
    ----------
    n_features : int - number of input features
    n_bins : int - number of bins per feature K
    tau : float - initial temperature for soft assignment
    """

    def __init__(self, n_features: int, n_bins: int, tau: float = 1.0):
        super().__init__()
        self.n_features = n_features
        self.n_bins = n_bins
        self.tau = tau
        # Learnable bin centers: (F, K)
        # Initialized to spread uniformly in [-2, 2] (standardized feature space)
        centers_init = torch.linspace(-2.0, 2.0, n_bins).unsqueeze(0).expand(n_features, -1).clone()
        self.bin_centers = nn.Parameter(centers_init)

    def set_temperature(self, tau: float):
        """Update soft assignment temperature (called during annealing)."""
        self.tau = tau

    def forward(self, x: Tensor) -> Tensor:
        """
        Compute soft bin membership for each sample.

        Parameters
        ----------
        x : (N, F) - input samples (standardized)

        Returns
        -------
        m : (N, F, K) - soft membership vectors
            m[n, f, k] = probability that sample n, feature f belongs to bin k
        """
        # x: (N, F) → (N, F, 1)
        x_expanded = x.unsqueeze(-1)
        # bin_centers: (F, K) → (1, F, K)
        centers = self.bin_centers.unsqueeze(0)
        # Negative squared distance to each bin center: (N, F, K)
        neg_dist = -((x_expanded - centers) ** 2) / (self.tau + 1e-8)
        # Soft membership via temperature-scaled softmax over bins
        m = F.softmax(neg_dist, dim=-1)  # (N, F, K)
        return m

    def hard_assignment(self, x: Tensor) -> Tensor:
        """
        Hard bin assignment for coverage measurement at evaluation time.
        Returns (N, F) integer bin indices.
        """
        with torch.no_grad():
            m = self.forward(x)
            return m.argmax(dim=-1)  # (N, F)
class BinLearningLoss(nn.Module):
    """
    Computes the bin learning objective (Eq. 1 in paper):
        L_bin = λ_intra · Var_intra(M, P_T) + λ_inter · (1 / Var_inter(M, P_T))

    Intra-bin variance:
        For each feature f and bin k, compute the variance of teacher
        predictions P_T weighted by soft memberships m_k:
            var_intra_fk = Σ_n m_nfk · (P_T_n − μ_fk)² / Σ_n m_nfk
        where μ_fk = Σ_n m_nfk · P_T_n / Σ_n m_nfk
        Average over all features and bins.

    Inter-bin variance:
        Variance of bin-mean predictions across bins (per feature).
        We want bins to have different mean predictions → maximize this.
    """

    def __init__(self, config: TabKDConfig):
        super().__init__()
        self.config = config

    def forward(self, m: Tensor, teacher_probs: Tensor) -> Tensor:
        """
        Parameters
        ----------
        m : (N, F, K) - soft bin memberships
        teacher_probs: (N, C) - teacher output probabilities

        Returns
        -------
        loss : scalar - bin learning loss
        """
        n_samples = m.shape[0]
        # Use the max teacher probability as the scalar prediction signal
        p = teacher_probs.max(dim=-1).values  # (N,)
        # Reshape p for broadcasting against m: (N, 1, 1)
        p_exp = p.view(n_samples, 1, 1)
        # Total membership mass per bin: (F, K)
        m_sum = m.sum(dim=0) + 1e-8
        # Bin means μ_fk = Σ_n m_nfk · p_n / Σ_n m_nfk : (F, K)
        bin_means = (m * p_exp).sum(dim=0) / m_sum
        # Intra-bin variance: membership-weighted squared deviation, (F, K)
        diff_sq = (p_exp - bin_means.unsqueeze(0)) ** 2
        intra_var = (m * diff_sq).sum(dim=0) / m_sum
        loss_intra = intra_var.mean()
        # Inter-bin variance: variance of bin_means across bins, per feature
        inter_var = bin_means.var(dim=-1) + 1e-8  # (F,)
        loss_inter = (1.0 / inter_var).mean()
        loss = self.config.lambda_intra * loss_intra + self.config.lambda_inter * loss_inter
        return loss
# ─── SECTION 3: Interaction Diversity Loss ────────────────────────────────────
class InteractionDiversityLoss(nn.Module):
    """
    Core TabKD contribution: maximizes entropy over pairwise bin-combination
    distributions to achieve systematic feature interaction coverage (Eq. 3–4).

    For each pair of features (i, j), computes the empirical joint distribution
    P(k1, k2 | i, j) over K² bin combinations (Eq. 3):
        P(k1, k2 | i, j) = (1/N) Σ_x m_k1^(i)(x) · m_k2^(j)(x)

    Diversity loss = negative average entropy over all feature pairs (Eq. 4):
        L_diversity = −(1/C(F,2)) Σ_{i<j} H(P(·, · | i, j))
    where H(P) = −Σ_{k1,k2} P(k1, k2 | i, j) · log P(k1, k2 | i, j).
    """

    def __init__(self, config: TabKDConfig):
        super().__init__()
        self.config = config
        n_feat = config.n_features
        self.n_pairs = n_feat * (n_feat - 1) // 2
        # Precompute all unique feature pair indices
        self.pairs = [(i, j) for i in range(n_feat) for j in range(i + 1, n_feat)]

    def forward(self, m: Tensor) -> Tensor:
        """
        Parameters
        ----------
        m : (N, F, K) - soft bin memberships from the bin learner

        Returns
        -------
        loss : scalar - negative average entropy (minimize = maximize diversity)
        """
        total_entropy = 0.0
        count = 0
        for i, j in self.pairs:
            m_i = m[:, i, :]  # (N, K)
            m_j = m[:, j, :]  # (N, K)
            # Joint distribution over K² bin combinations (Eq. 3)
            # outer product per sample: (N, K, K)
            joint = m_i.unsqueeze(-1) * m_j.unsqueeze(-2)
            # Average over batch: empirical joint distribution (K, K)
            p_joint = joint.mean(dim=0)
            # Flatten to (K²,) and compute entropy
            p_flat = p_joint.reshape(-1) + 1e-10  # numerical stability
            p_flat = p_flat / p_flat.sum()  # normalize to valid distribution
            entropy = -(p_flat * p_flat.log()).sum()
            total_entropy = total_entropy + entropy
            count += 1
        # Negative entropy: minimizing this maximizes entropy (uniform coverage)
        loss = -(total_entropy / max(1, count))
        return loss

    @torch.no_grad()
    def compute_coverage(self, bin_assignments: Tensor) -> float:
        """
        Compute cumulative interaction coverage (Eq. 13) given hard bin assignments.

        Parameters
        ----------
        bin_assignments : (N, F) - integer bin indices (from SoftBinLearner.hard_assignment)

        Returns
        -------
        coverage : float in [0, 1] - fraction of pairwise bin cells visited
        """
        K = self.config.n_bins
        visited = set()
        # Convert once to nested Python lists to avoid per-element .item() calls
        for row in bin_assignments.tolist():
            for i, j in self.pairs:
                visited.add((i, j, row[i], row[j]))
        total_cells = len(self.pairs) * (K * K)
        return len(visited) / max(1, total_cells)
# ─── SECTION 4: Generator Network ─────────────────────────────────────────────
class TabularGenerator(nn.Module):
    """
    Generator G(z): maps noise z ~ N(0, I) to synthetic tabular samples.

    Architecture: 3-layer MLP with Batch Normalization and LeakyReLU activations.
    Output: F-dimensional synthetic samples in the standardized feature space.

    The generator is trained to simultaneously:
        1. Maximize interaction diversity (L_diversity): cover all bin pairs
        2. Maximize student-teacher disagreement (L_hardness): target weaknesses

    Parameters
    ----------
    config : TabKDConfig
    """

    def __init__(self, config: TabKDConfig):
        super().__init__()
        self.config = config
        H = config.gen_hidden
        n_out = config.n_features
        self.net = nn.Sequential(
            nn.Linear(config.noise_dim, H),
            nn.BatchNorm1d(H),
            nn.LeakyReLU(0.2),
            nn.Linear(H, H * 2),
            nn.BatchNorm1d(H * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(H * 2, H),
            nn.BatchNorm1d(H),
            nn.LeakyReLU(0.2),
            nn.Linear(H, n_out),
            nn.Tanh(),  # Output in [-1, 1] ≈ standardized feature range
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, z: Tensor) -> Tensor:
        """
        Parameters
        ----------
        z : (N, noise_dim) - Gaussian noise

        Returns
        -------
        x_gen : (N, F) - synthetic tabular samples
        """
        return self.net(z)

    def sample(self, n: int, device: torch.device) -> Tensor:
        """Generate n synthetic samples."""
        z = torch.randn(n, self.config.noise_dim, device=device)
        return self.forward(z)
# ─── SECTION 5: Student Network ───────────────────────────────────────────────
class TabularStudent(nn.Module):
    """
    Student S(x): single-hidden-layer MLP (32 units, ReLU) as per Section 5.3.

    Deliberately lightweight: the same architecture is used regardless of
    teacher type, making results fair and comparable across configurations.

    Training objective (Eq. 7):
        L_student = (1/N) Σ_x D_KL(S(x) ‖ T(x))
    where T(x) provides soft label targets from the teacher.

    Parameters
    ----------
    config : TabKDConfig
    """

    def __init__(self, config: TabKDConfig):
        super().__init__()
        self.config = config
        self.net = nn.Sequential(
            nn.Linear(config.n_features, config.student_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(config.student_hidden, config.n_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Returns log-probabilities via log_softmax."""
        return F.log_softmax(self.net(x), dim=-1)

    def predict(self, x: Tensor) -> Tensor:
        """Returns class predictions."""
        with torch.no_grad():
            return self.forward(x).argmax(dim=-1)
# ─── SECTION 6: Teacher Wrappers ──────────────────────────────────────────────
class NeuralNetworkTeacher(nn.Module):
    """
    Pre-trained neural network teacher (128 → 64 units, ReLU, dropout 0.2).
    As per Section 5.3.
    """

    def __init__(self, n_features: int, n_classes: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, n_classes),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Returns class probabilities (softmax)."""
        logits = self.net(x)
        return F.softmax(logits, dim=-1)

    def predict(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            return self.forward(x).argmax(dim=-1)

    def train_on_data(
        self,
        X_train: Tensor,
        y_train: Tensor,
        epochs: int = 30,
        lr: float = 0.001,
    ) -> float:
        """Quick teacher pre-training. Returns final loss."""
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        dataset = TensorDataset(X_train, y_train)
        loader = DataLoader(dataset, batch_size=64, shuffle=True)
        final_loss = 0.0
        for epoch in range(epochs):
            for xb, yb in loader:
                optimizer.zero_grad()
                # CrossEntropyLoss expects raw logits, so bypass the softmax in forward()
                loss = criterion(self.net(xb), yb)
                loss.backward()
                optimizer.step()
            final_loss = loss.item()
        return final_loss
class XGBoostTeacher:
    """
    XGBoost teacher wrapper. Non-differentiable tree ensemble.
    TabKD handles this via black-box querying (no backprop through teacher).

    Requires: pip install xgboost
    Falls back to mock implementation for smoke testing.
    """

    def __init__(self, n_estimators: int = 100, max_depth: int = 6, lr: float = 0.1):
        self.params = {"n_estimators": n_estimators, "max_depth": max_depth, "learning_rate": lr}
        self._model = None
        self._mock = False

    def fit(self, X: np.ndarray, y: np.ndarray):
        try:
            from xgboost import XGBClassifier
            # use_label_encoder is omitted: it is deprecated in recent xgboost releases
            self._model = XGBClassifier(**self.params, eval_metric="logloss")
            self._model.fit(X, y)
        except ImportError:
            print(" [XGBoostTeacher] xgboost not installed. Using mock teacher.")
            self._mock = True

    def predict_proba_tensor(self, x: Tensor) -> Tensor:
        """Returns teacher probabilities as a Tensor (N, C)."""
        if self._mock or self._model is None:
            # Mock path: random binary-class probabilities on the same device as x
            N = x.shape[0]
            probs = torch.rand(N, 2, device=x.device)
            return probs / probs.sum(dim=-1, keepdim=True)
        X_np = x.detach().cpu().numpy()
        probs = self._model.predict_proba(X_np)
        return torch.tensor(probs, dtype=torch.float32, device=x.device)

    def predict(self, x: Tensor) -> Tensor:
        return self.predict_proba_tensor(x).argmax(dim=-1)

class RandomForestTeacher:
    """
    Random Forest teacher wrapper. Non-differentiable ensemble.
    Requires: pip install scikit-learn
    Falls back to a mock implementation (random binary probabilities) for smoke testing.
    """

    def __init__(self, n_estimators: int = 100, max_depth: int = 10, min_samples_split: int = 5):
        self.params = {"n_estimators": n_estimators, "max_depth": max_depth,
                       "min_samples_split": min_samples_split}
        self._model = None
        self._mock = False

    def fit(self, X: np.ndarray, y: np.ndarray):
        try:
            from sklearn.ensemble import RandomForestClassifier
            self._model = RandomForestClassifier(**self.params, random_state=42)
            self._model.fit(X, y)
        except ImportError:
            print(" [RandomForestTeacher] sklearn not installed. Using mock teacher.")
            self._mock = True

    def predict_proba_tensor(self, x: Tensor) -> Tensor:
        """Returns teacher probabilities as a Tensor (N, C)."""
        if self._mock or self._model is None:
            # Mock path: random two-class probabilities on the same device as x
            N = x.shape[0]
            probs = torch.rand(N, 2, device=x.device)
            return probs / probs.sum(dim=-1, keepdim=True)
        X_np = x.detach().cpu().numpy()
        probs = self._model.predict_proba(X_np)
        return torch.tensor(probs, dtype=torch.float32, device=x.device)

    def predict(self, x: Tensor) -> Tensor:
        return self.predict_proba_tensor(x).argmax(dim=-1)

class TabTransformerTeacher(nn.Module):
    """
    Lightweight TabTransformer teacher (Section 5.3).
    Multi-head self-attention over feature embeddings.
    Production: use the full TabTransformer implementation:
        pip install tab-transformer-pytorch
    """

    def __init__(self, n_features: int, n_classes: int, d_model: int = 32, n_heads: int = 4, n_layers: int = 2):
        super().__init__()
        self.feature_embed = nn.Linear(1, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True, dropout=0.1)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.head = nn.Linear(d_model * n_features, n_classes)
        self.n_features = n_features

    def forward(self, x: Tensor) -> Tensor:
        # x: (N, F) → (N, F, 1) → embed → (N, F, d_model)
        x_emb = self.feature_embed(x.unsqueeze(-1))
        x_out = self.transformer(x_emb)
        x_flat = x_out.reshape(x.shape[0], -1)
        return F.softmax(self.head(x_flat), dim=-1)

    def predict(self, x: Tensor) -> Tensor:
        with torch.no_grad():
            return self.forward(x).argmax(dim=-1)

    def train_on_data(self, X_train: Tensor, y_train: Tensor, epochs: int = 20) -> float:
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        dataset = TensorDataset(X_train, y_train)
        loader = DataLoader(dataset, batch_size=64, shuffle=True)
        final_loss = 0.0
        for _ in range(epochs):
            for xb, yb in loader:
                optimizer.zero_grad()
                # Recompute raw logits (forward() applies softmax; cross_entropy needs logits)
                logits = self.head(
                    self.transformer(self.feature_embed(xb.unsqueeze(-1))).reshape(xb.shape[0], -1)
                )
                loss = F.cross_entropy(logits, yb)
                loss.backward()
                optimizer.step()
                final_loss = loss.item()
        return final_loss

def get_teacher_probs(teacher, x: Tensor) -> Tensor:
    """Unified interface to get teacher soft label probabilities."""
    if isinstance(teacher, (NeuralNetworkTeacher, TabTransformerTeacher)):
        with torch.no_grad():
            return teacher(x)
    elif isinstance(teacher, (XGBoostTeacher, RandomForestTeacher)):
        return teacher.predict_proba_tensor(x)
    raise ValueError(f"Unknown teacher type: {type(teacher)}")

# ─── SECTION 7: Replay Buffer ─────────────────────────────────────────────────
class ReplayBuffer:
    """
    Experience replay buffer populated during warmup (Section 4.5).
    Prevents catastrophic forgetting: the student is trained on a mix of
      - 90% adversarial samples (hard, diverse, from adversarial training)
      - 10% replay samples (typical, from warmup — stability anchors)
    This ensures the student does not overfit to edge cases during adversarial
    training while forgetting typical inputs.
    """

    def __init__(self, max_size: int = 2000):
        # deque with maxlen evicts the oldest samples once the buffer is full
        self.buffer = deque(maxlen=max_size)
        self.max_size = max_size

    def add(self, x: Tensor):
        """Add samples to the buffer."""
        for i in range(x.shape[0]):
            self.buffer.append(x[i].detach().cpu())

    def sample(self, n: int, device: torch.device) -> Optional[Tensor]:
        """Sample n items from the buffer. Returns None if buffer is empty."""
        if len(self.buffer) == 0:
            return None
        n = min(n, len(self.buffer))
        samples = random.sample(list(self.buffer), n)
        return torch.stack(samples, dim=0).to(device)

    def __len__(self):
        return len(self.buffer)

# ─── SECTION 8: TabKD Trainer ─────────────────────────────────────────────────
class TabKDTrainer:
    """
    Full TabKD training pipeline implementing all three phases (Section 5.5):

    Phase 0 — Warmup (30 epochs):
        Train student on random samples, populate replay buffer.
        Objective: Eq. 8 — standard KL divergence distillation.
    Phase 1 — Bin Stabilization (200 epochs, temperature annealing):
        Train bin learner to align with teacher decision boundaries.
        Run boundary-focused generator to supply near-boundary samples.
        Objective: Eq. 1 (bin loss) + Eq. 2 (boundary generator).
        Freeze bin learner upon completion.
    Phase 2 — Adversarial Distillation (400 epochs):
        Generator maximizes interaction coverage + student hardness.
        Student minimizes KL divergence from teacher.
        Mix: 90% adversarial + 10% replay buffer samples.
        Objective: Eq. 6 (generator) + Eq. 7 (student).
    """

    def __init__(self, config: TabKDConfig, teacher, device: torch.device):
        self.config = config
        self.teacher = teacher
        self.device = device
        self.bin_learner = SoftBinLearner(config.n_features, config.n_bins).to(device)
        self.generator = TabularGenerator(config).to(device)
        self.student = TabularStudent(config).to(device)
        self.replay_buffer = ReplayBuffer(config.replay_size)
        self.bin_loss_fn = BinLearningLoss(config)
        self.diversity_loss_fn = InteractionDiversityLoss(config)
        self.gen_optimizer = torch.optim.Adam(self.generator.parameters(), lr=config.lr)
        self.student_optimizer = torch.optim.Adam(self.student.parameters(), lr=config.lr)
        self.bin_optimizer = torch.optim.Adam(self.bin_learner.parameters(), lr=config.lr * 0.5)
        # Cosine annealing schedulers
        total_epochs = config.warmup_epochs + config.bin_epochs + config.adv_epochs
        self.gen_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.gen_optimizer, T_max=total_epochs
        )
        self.student_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.student_optimizer, T_max=total_epochs
        )
        self.history = []

    def _generate_batch(self, n: int) -> Tensor:
        """Generate a batch of synthetic samples."""
        return self.generator.sample(n, self.device)

    def _teacher_probs(self, x: Tensor) -> Tensor:
        """Get teacher soft labels for a batch of samples."""
        return get_teacher_probs(self.teacher, x)

    def _annealed_temperature(self, epoch: int) -> float:
        """Linear temperature annealing during bin learning (Eq. 9)."""
        tau = self.config.tau_start - (
            (self.config.tau_start - self.config.tau_end) * epoch / max(1, self.config.bin_epochs)
        )
        return max(self.config.tau_end, tau)

    def _warmup(self):
        """
        Phase 0: Pre-train student on random samples.
        Populates replay buffer for later catastrophic forgetting prevention.
        (Eq. 8: L_warmup = KL(S(x) ‖ T(x)))
        """
        print(f" [Phase 0] Warmup: {self.config.warmup_epochs} epochs...")
        self.student.train()
        total_loss = 0.0
        for epoch in range(self.config.warmup_epochs):
            # Detach the samples: the generator is not trained during warmup
            x_rand = self._generate_batch(self.config.batch_size).detach()
            teacher_probs = self._teacher_probs(x_rand)
            # The student returns log-probabilities. F.kl_div(log_q, p) computes
            # sum p * (log p - log q), with the teacher probabilities as the target p.
            log_student = self.student(x_rand)
            loss = F.kl_div(log_student, teacher_probs, reduction="batchmean")
            self.student_optimizer.zero_grad()
            loss.backward()
            self.student_optimizer.step()
            total_loss += loss.item()
            # Populate replay buffer with diverse warmup samples
            self.replay_buffer.add(x_rand)
        # max(1, ...) guards against division by zero when warmup_epochs == 0
        avg_loss = total_loss / max(1, self.config.warmup_epochs)
        print(f" Warmup complete | avg KL loss: {avg_loss:.4f} | replay buffer: {len(self.replay_buffer)} samples")

    def _bin_stabilization(self):
        """
        Phase 1: Learn adaptive bin boundaries aligned with teacher decisions.
        Temperature anneals from tau_start to tau_end across bin_epochs.
        Bins are frozen after this phase completes.
        (Eq. 1: L_bin; Eq. 2: L_gen^(1))
        """
        print(f" [Phase 1] Bin Stabilization: {self.config.bin_epochs} epochs...")
        self.bin_learner.train()
        self.generator.train()
        total_bin_loss = 0.0
        for epoch in range(self.config.bin_epochs):
            # Update temperature for annealing
            tau = self._annealed_temperature(epoch)
            self.bin_learner.set_temperature(tau)

            # ── Generator: boundary-focused sampling (Eq. 2) ────────────────
            x_gen = self._generate_batch(self.config.batch_size)
            teacher_probs = self._teacher_probs(x_gen)
            student_log_probs = self.student(x_gen)
            # L_class-div: negative entropy of the mean student prediction.
            # Minimizing it maximizes the entropy of the class marginals (class balance).
            student_probs = student_log_probs.exp()
            class_mean = student_probs.mean(dim=0)
            l_class_div = (class_mean * class_mean.log().clamp(min=-10)).sum()
            # L_entropy: negative teacher entropy. Minimizing it maximizes teacher
            # uncertainty, which is highest near decision boundaries.
            # Note: teacher outputs are queried without gradients (black-box), so in
            # this adaptation the term carries no gradient back to the generator.
            t_entropy = -(teacher_probs * (teacher_probs + 1e-10).log()).sum(dim=-1).mean()
            l_entropy = -t_entropy
            gen_loss_phase1 = (self.config.lambda_div * l_class_div
                               + self.config.lambda_boundary * l_entropy)
            self.gen_optimizer.zero_grad()
            # No retain_graph needed: the bin-learner step below uses detached samples
            gen_loss_phase1.backward()
            self.gen_optimizer.step()

            # ── Bin Learner: align bins with teacher decision boundaries (Eq. 1)
            x_gen_detach = x_gen.detach()
            teacher_probs_detach = self._teacher_probs(x_gen_detach)
            m = self.bin_learner(x_gen_detach)
            bin_loss = self.bin_loss_fn(m, teacher_probs_detach)
            self.bin_optimizer.zero_grad()
            bin_loss.backward()
            self.bin_optimizer.step()
            total_bin_loss += bin_loss.item()
            if (epoch + 1) % 50 == 0:
                print(f" Bin epoch {epoch+1}/{self.config.bin_epochs} | bin_loss={bin_loss.item():.4f} | τ={tau:.3f}")

        # Freeze bin learner
        for p in self.bin_learner.parameters():
            p.requires_grad = False
        self.bin_learner.eval()
        print(f" Bins frozen. avg_bin_loss={total_bin_loss / max(1, self.config.bin_epochs):.4f}")

    def _adversarial_distillation(self):
        """
        Phase 2: Adversarial generator + student training with frozen bins.
        Generator maximizes interaction coverage + hardness (Eq. 6).
        Student minimizes KL from teacher (Eq. 7).
        Training mix: 90% adversarial + 10% replay.
        """
        print(f" [Phase 2] Adversarial Distillation: {self.config.adv_epochs} epochs...")
        n_replay = max(1, int(self.config.batch_size * self.config.replay_ratio))
        n_adv = self.config.batch_size - n_replay
        all_bin_assignments = []  # collect for coverage tracking
        for epoch in range(self.config.adv_epochs):
            self.generator.train()
            self.student.train()

            # ── Generate adversarial samples ──────────────────────────────────
            x_adv = self._generate_batch(n_adv)
            teacher_probs = self._teacher_probs(x_adv)

            # ── Generator Loss (Eq. 6): L_gen = L_diversity + λ_hard · L_hardness
            # Bin parameters are frozen, but gradients still flow to the generator via x_adv
            m_gen = self.bin_learner(x_adv)
            l_diversity = self.diversity_loss_fn(m_gen)
            # Keep the student in the graph (no .detach()) so the hardness term can
            # send gradients through the student back to the generator. Only the
            # generator optimizer steps here; student grads are zeroed before its update.
            student_log_probs = self.student(x_adv)
            # L_hardness (Eq. 5): reward generator for finding student failures
            l_hardness = -F.kl_div(
                student_log_probs, teacher_probs, reduction="batchmean"
            )
            gen_loss = (self.config.lambda_cov * l_diversity
                        + self.config.lambda_hard * l_hardness)
            self.gen_optimizer.zero_grad()
            gen_loss.backward()
            nn.utils.clip_grad_norm_(self.generator.parameters(), 5.0)
            self.gen_optimizer.step()

            # ── Student Loss (Eq. 7): minimize KL(S(x) ‖ T(x)) ───────────────
            x_adv_fresh = self._generate_batch(n_adv).detach()
            teacher_probs_fresh = self._teacher_probs(x_adv_fresh)
            # Mix with replay buffer samples
            replay = self.replay_buffer.sample(n_replay, self.device)
            if replay is not None:
                x_mix = torch.cat([x_adv_fresh, replay], dim=0)
                t_replay = self._teacher_probs(replay)
                t_mix = torch.cat([teacher_probs_fresh, t_replay], dim=0)
            else:
                x_mix, t_mix = x_adv_fresh, teacher_probs_fresh
            student_loss = F.kl_div(self.student(x_mix), t_mix, reduction="batchmean")
            self.student_optimizer.zero_grad()
            student_loss.backward()
            nn.utils.clip_grad_norm_(self.student.parameters(), 5.0)
            self.student_optimizer.step()

            # Track bin assignments for coverage measurement
            with torch.no_grad():
                hard_assignments = self.bin_learner.hard_assignment(x_adv_fresh)
                all_bin_assignments.append(hard_assignments.cpu())
            self.gen_scheduler.step()
            self.student_scheduler.step()

            # Log (and record history) every 100 epochs; coverage is only computed here
            if (epoch + 1) % 100 == 0:
                all_ba = torch.cat(all_bin_assignments, dim=0)
                coverage = self.diversity_loss_fn.compute_coverage(all_ba)
                print(
                    f" Adv epoch {epoch+1}/{self.config.adv_epochs} | "
                    f"gen_loss={gen_loss.item():.4f} | "
                    f"student_kl={student_loss.item():.4f} | "
                    f"coverage={coverage:.1%}"
                )
                self.history.append({
                    "epoch": epoch + 1,
                    "gen_loss": gen_loss.item(),
                    "student_kl": student_loss.item(),
                    "coverage": coverage,
                })

    def train(self):
        """Execute full three-phase TabKD training."""
        print(f"\n{'='*60}")
        print(f" TabKD Training | {self.config.n_features} features | {self.config.n_bins} bins")
        print(f" Phases: Warmup({self.config.warmup_epochs}) → "
              f"Bins({self.config.bin_epochs}) → Adv({self.config.adv_epochs})")
        print(f"{'='*60}\n")
        self._warmup()
        self._bin_stabilization()
        self._adversarial_distillation()
        print("\n TabKD training complete.")

    @torch.no_grad()
    def evaluate(self, X_test: Tensor, y_test: Tensor) -> Dict[str, float]:
        """
        Evaluate student model on held-out test data.
        Returns accuracy, F1, teacher-student agreement, and coverage.
        (Eqs. 10–13 in paper)
        """
        self.student.eval()
        student_preds = self.student.predict(X_test)
        teacher_preds = get_teacher_probs(self.teacher, X_test).argmax(dim=-1)
        y_np = y_test.cpu().numpy()
        s_np = student_preds.cpu().numpy()
        t_np = teacher_preds.cpu().numpy()
        acc = (s_np == y_np).mean()
        # F1 score (binary, positive class = 1)
        tp = ((s_np == 1) & (y_np == 1)).sum()
        fp = ((s_np == 1) & (y_np == 0)).sum()
        fn = ((s_np == 0) & (y_np == 1)).sum()
        precision = tp / max(1, tp + fp)
        recall = tp / max(1, tp + fn)
        f1 = 2 * precision * recall / max(1e-8, precision + recall)
        # Teacher-student agreement (Eq. 12)
        agreement = (s_np == t_np).mean()
        # Coverage (Eq. 13)
        hard_assignments = self.bin_learner.hard_assignment(X_test)
        coverage = self.diversity_loss_fn.compute_coverage(hard_assignments.cpu())
        return {
            "accuracy": float(acc) * 100,
            "f1": float(f1) * 100,
            "agreement": float(agreement) * 100,
            "coverage": float(coverage) * 100,
        }

# ─── SECTION 9: Dataset Helpers & Evaluation ──────────────────────────────────
def load_benchmark_dataset(
    name: str = "adult",
    n_samples: int = 1000,
    test_ratio: float = 0.2,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Load one of the four TabKD benchmark datasets (Section 5.2).

    Production: download real datasets from UCI Machine Learning Repository:
      - Adult: https://doi.org/10.24432/C5XW20
      - Breast Cancer: https://doi.org/10.24432/C51P4M
      - Credit: https://doi.org/10.24432/C55S3H
      - Mushroom: https://doi.org/10.24432/C5959T

    Production loading example (requires sklearn):
        from sklearn.datasets import load_breast_cancer
        from sklearn.preprocessing import StandardScaler
        data = load_breast_cancer()
        X = StandardScaler().fit_transform(data.data)
        y = data.target

    For "breast_cancer", the real dataset is loaded via scikit-learn when it is
    installed (n_samples is ignored in that case). In every other case this
    function generates synthetic data with the same number of features as the
    real dataset, for demonstration purposes.

    Returns
    -------
    X_train, X_test : float32 Tensors
    y_train, y_test : int64 Tensors of class labels
    """
    dataset_specs = {
        "adult": {"n_features": 14, "name": "Adult Income (48K, 14 features)"},
        "credit": {"n_features": 23, "name": "Credit Default (30K, 23 features)"},
        "breast_cancer": {"n_features": 30, "name": "Breast Cancer (569, 30 features)"},
        "mushroom": {"n_features": 22, "name": "Mushroom (8K, 22 features)"},
    }
    if name not in dataset_specs:
        raise ValueError(f"Unknown dataset '{name}'. Choose from: {list(dataset_specs.keys())}")
    spec = dataset_specs[name]
    n_features = spec["n_features"]

    # Try to load real datasets via sklearn if available
    if name == "breast_cancer":
        try:
            from sklearn.datasets import load_breast_cancer
            from sklearn.preprocessing import StandardScaler
            from sklearn.model_selection import train_test_split
            data = load_breast_cancer()
            X = StandardScaler().fit_transform(data.data).astype(np.float32)
            y = data.target.astype(np.int64)
            X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=test_ratio, stratify=y, random_state=42)
            print(f" Loaded real Breast Cancer dataset: {X.shape[0]} samples, {X.shape[1]} features")
            return (torch.tensor(X_tr), torch.tensor(X_te),
                    torch.tensor(y_tr), torch.tensor(y_te))
        except ImportError:
            pass

    # Synthetic fallback
    torch.manual_seed(42)
    X = torch.randn(n_samples, n_features)
    # Synthetic labels: XOR of first two features (non-linear interaction)
    y = ((X[:, 0] > 0).long() ^ (X[:, 1] > 0).long())
    n_test = int(n_samples * test_ratio)
    n_train = n_samples - n_test
    X_train, X_test = X[:n_train], X[n_train:]
    y_train, y_test = y[:n_train], y[n_train:]
    print(f" Loaded synthetic '{name}' data: {n_train} train / {n_test} test | {n_features} features")
    return X_train, X_test, y_train, y_test

def run_tabkd_experiment(
    dataset_name: str = "breast_cancer",
    teacher_type: str = "nn",
    n_samples: int = 800,
    warmup_epochs: int = 5,
    bin_epochs: int = 20,
    adv_epochs: int = 40,
    n_bins: int = 8,
    verbose: bool = True,
) -> Dict[str, float]:
    """
    Run a complete TabKD experiment for a dataset + teacher combination.

    Parameters
    ----------
    dataset_name : "adult" | "credit" | "breast_cancer" | "mushroom"
    teacher_type : "nn" | "xgboost" | "rf" | "tabtransformer"
    n_samples : total samples (use full dataset in production)
    warmup_epochs : Phase 0 epochs (paper: 30)
    bin_epochs : Phase 1 epochs (paper: 200)
    adv_epochs : Phase 2 epochs (paper: 400)
    n_bins : bins per feature (paper: 8)

    Returns
    -------
    dict with accuracy, f1, agreement, coverage metrics
    """
    device = torch.device("cpu")
    print(f"\n{'─'*60}")
    print(f" Dataset: {dataset_name} | Teacher: {teacher_type} | Bins: {n_bins}")
    print(f"{'─'*60}")

    # Load dataset; the feature count is taken from the loaded data
    X_train, X_test, y_train, y_test = load_benchmark_dataset(dataset_name, n_samples)
    n_features = X_train.shape[1]

    # Build and train teacher
    if teacher_type == "nn":
        teacher = NeuralNetworkTeacher(n_features, n_classes=2).to(device)
        teacher.train_on_data(X_train, y_train, epochs=15)
        teacher.eval()
    elif teacher_type == "xgboost":
        teacher = XGBoostTeacher()
        teacher.fit(X_train.numpy(), y_train.numpy())
    elif teacher_type == "rf":
        teacher = RandomForestTeacher()
        teacher.fit(X_train.numpy(), y_train.numpy())
    elif teacher_type == "tabtransformer":
        teacher = TabTransformerTeacher(n_features, n_classes=2).to(device)
        teacher.train_on_data(X_train, y_train, epochs=15)
        teacher.eval()
    else:
        raise ValueError(f"Unknown teacher type: {teacher_type}")
    # Every teacher wrapper exposes predict(Tensor) -> Tensor of class indices
    teacher_acc = (teacher.predict(X_test) == y_test).float().mean().item() * 100
    print(f" Teacher accuracy: {teacher_acc:.1f}%")

    # Build TabKD config
    config = TabKDConfig(
        n_features=n_features,
        n_classes=2,
        n_bins=n_bins,
        warmup_epochs=warmup_epochs,
        bin_epochs=bin_epochs,
        adv_epochs=adv_epochs,
    )

    # Run TabKD training
    trainer = TabKDTrainer(config, teacher, device)
    trainer.train()

    # Evaluate
    results = trainer.evaluate(X_test.to(device), y_test.to(device))
    results["teacher_accuracy"] = teacher_acc
    print(f"\n Results | Acc={results['accuracy']:.1f}% | F1={results['f1']:.1f}% | "
          f"Agreement={results['agreement']:.1f}% | Coverage={results['coverage']:.1f}%")
    print(f" (Paper target: {teacher_type}+{dataset_name} → see Table 2 for expected values)")
    return results

# ─── SECTION 10: Smoke Test ────────────────────────────────────────────────────
if __name__ == "__main__":
    print("=" * 60)
    print("TabKD — Full Framework Smoke Test")
    print("=" * 60)
    torch.manual_seed(42)
    np.random.seed(42)
    device = torch.device("cpu")

    # ── 1. Soft Bin Learner ──────────────────────────────────────────────────
    print("\n[1/7] Soft Bin Learner — shape and range checks...")
    config = TabKDConfig(n_features=14, n_classes=2, n_bins=8)
    bin_learner = SoftBinLearner(config.n_features, config.n_bins).to(device)
    x_test = torch.randn(32, config.n_features)
    m = bin_learner(x_test)
    assert m.shape == (32, 14, 8), f"Expected (32,14,8), got {m.shape}"
    assert (m >= 0).all() and (m <= 1).all(), "Memberships out of [0,1]"
    assert torch.allclose(m.sum(dim=-1), torch.ones(32, 14), atol=1e-5), "Memberships must sum to 1 per feature"
    print(f" Membership tensor shape: {tuple(m.shape)} ✓")
    print(" Values in [0,1], sum to 1 per feature ✓")

    # ── 2. Bin Learning Loss ─────────────────────────────────────────────────
    print("\n[2/7] Bin Learning Loss...")
    bin_loss_fn = BinLearningLoss(config)
    teacher_probs = torch.softmax(torch.randn(32, 2), dim=-1)
    loss_bin = bin_loss_fn(m, teacher_probs)
    assert torch.isfinite(loss_bin), "Bin loss is not finite"
    print(f" L_bin = {loss_bin.item():.4f} (finite, differentiable) ✓")

    # ── 3. Interaction Diversity Loss ────────────────────────────────────────
    print("\n[3/7] Interaction Diversity Loss & Coverage...")
    div_loss_fn = InteractionDiversityLoss(config)
    loss_div = div_loss_fn(m)
    assert torch.isfinite(loss_div), "Diversity loss is not finite"
    n_pairs = config.n_features * (config.n_features - 1) // 2
    print(f" Feature pairs: {n_pairs} | L_diversity = {loss_div.item():.4f} ✓")
    hard_assignments = bin_learner.hard_assignment(x_test)
    coverage = div_loss_fn.compute_coverage(hard_assignments.cpu())
    print(f" Pairwise coverage on 32 samples: {coverage:.1%}")
    total_cells = n_pairs * (config.n_bins ** 2)
    print(f" Total pairwise bin cells: {total_cells:,}")

    # ── 4. Generator & Student Forward Pass ──────────────────────────────────
    print("\n[4/7] Generator & Student Networks...")
    gen = TabularGenerator(config).to(device)
    student = TabularStudent(config).to(device)
    x_gen = gen.sample(64, device)
    assert x_gen.shape == (64, 14), f"Generator output shape mismatch: {x_gen.shape}"
    log_probs = student(x_gen)
    assert log_probs.shape == (64, 2), f"Student output shape mismatch: {log_probs.shape}"
    assert torch.allclose(log_probs.exp().sum(dim=-1), torch.ones(64), atol=1e-5)
    print(f" Generator: noise({config.noise_dim}) → samples({config.n_features}) ✓")
    print(f" Student: samples({config.n_features}) → log_probs({config.n_classes}) ✓")
    n_gen = sum(p.numel() for p in gen.parameters())
    n_stu = sum(p.numel() for p in student.parameters())
    print(f" Generator params: {n_gen:,} | Student params: {n_stu:,}")

    # ── 5. Neural Network Teacher ─────────────────────────────────────────────
    print("\n[5/7] Teacher Training (Neural Network)...")
    teacher_nn = NeuralNetworkTeacher(config.n_features, config.n_classes).to(device)
    X_mock = torch.randn(200, config.n_features)
    y_mock = ((X_mock[:, 0] > 0).long() ^ (X_mock[:, 1] > 0).long())
    loss_final = teacher_nn.train_on_data(X_mock, y_mock, epochs=10)
    teacher_nn.eval()
    teacher_probs_mock = get_teacher_probs(teacher_nn, X_mock[:8])
    assert teacher_probs_mock.shape == (8, 2)
    assert torch.allclose(teacher_probs_mock.sum(dim=-1), torch.ones(8), atol=1e-5)
    print(f" Teacher trained | final loss: {loss_final:.4f}")
    print(f" Teacher probs shape: {tuple(teacher_probs_mock.shape)} | sum to 1 ✓")

    # ── 6. Replay Buffer ─────────────────────────────────────────────────────
    print("\n[6/7] Replay Buffer...")
    buf = ReplayBuffer(max_size=500)
    for _ in range(10):
        buf.add(torch.randn(20, config.n_features))
    sample = buf.sample(32, device)
    assert sample is not None and sample.shape[1] == config.n_features
    print(f" Buffer size: {len(buf)} | Sample shape: {tuple(sample.shape)} ✓")

    # ── 7. Full Training Run (short) ──────────────────────────────────────────
    print("\n[7/7] Full TabKD Training Run (Breast Cancer, NN teacher, mini-epochs)...")
    results = run_tabkd_experiment(
        dataset_name="breast_cancer",
        teacher_type="nn",
        n_samples=400,
        warmup_epochs=3,
        bin_epochs=10,
        adv_epochs=20,
        n_bins=8,
    )
    print("\n" + "=" * 60)
    print("✓ All TabKD checks passed. Framework is ready for use.")
    print("=" * 60)
    print("""
    Next steps to reproduce paper results:
      1. Download real datasets from UCI Repository:
           Adult: https://doi.org/10.24432/C5XW20
           Breast Cancer: https://doi.org/10.24432/C51P4M
           Credit: https://doi.org/10.24432/C55S3H
           Mushroom: https://doi.org/10.24432/C5959T
      2. Install full dependencies:
           pip install scikit-learn xgboost tab-transformer-pytorch
      3. Run full training with paper hyperparameters:
           config = TabKDConfig(
               n_bins=8, batch_size=128, lr=0.001,
               warmup_epochs=30, bin_epochs=200, adv_epochs=400,
               lambda_cov=10.0, lambda_hard=2.0,
           )
      4. Apply teacher-specific temperature schedules (Table 1):
           NN: tau_start=1.0, tau_end=0.05
           Random Forest: tau_start=1.2, tau_end=0.08
           XGBoost: tau_start=1.5, tau_end=0.10
           TabTransformer: tau_start=1.2, tau_end=0.08
      5. Access open-source implementation:
           https://anonymous.4open.science/r/int_div-0413/README.md
      6. Expected results (agreement %):
           Breast Cancer + NN: ~95.6%
           Breast Cancer + XGBoost: ~96.5%
           Credit + NN: ~97.0%
           Mushroom + NN: ~96.5%
    """)
Read the Full Paper & Access the Code
The complete study — including all 16 dataset-teacher result tables, coverage-vs-agreement correlation plots across all teacher types, ablation tables, and the open-source implementation — is available on arXiv and the anonymous code repository.
Pereira, S. N., Khadka, K., & Lei, Y. (2026). TabKD: Tabular Knowledge Distillation through Interaction Diversity of Learned Feature Bins. arXiv:2603.15481v1 [cs.LG]. University of Texas at Arlington.
This article is an independent editorial analysis of peer-reviewed research. The Python implementation is an educational adaptation covering all described algorithmic components. For exact replication, use the official repository with real benchmark datasets downloaded from the UCI Machine Learning Repository.
