PRECTR-V2: How Alibaba Solved Cold-Start, Exposure Bias, and Frozen Encoders in One Unified Search Ranking Framework | AI Trend Blend

PRECTR-V2: How Alibaba Solved Cold-Start, Exposure Bias, and a Frozen Encoder — All in One Unified Search Ranking Framework

Engineers at Alibaba’s Xianyu platform built three interlocking improvements on top of their prior unified relevance-CTR model, ultimately achieving a 1.39% lift in per-capita orders and a 3.18% GMV increase in live A/B testing.

[Figure 1 diagram: PRECTR-V2 overview with three panels: "Cold-Start Preference Mining" (cross-user MHTA attention, r^user + r^cate → τ(x), Expert+ / Expert− MoE), "Relevance Exposure Debias" (noise injection → fake_rsl, pairwise margin clip + dynamic truncation), and "LLM-Distilled Encoder" (3-layer lightweight Transformer, ~2M params, Qwen-7B teacher distillation, end-to-end CTR fine-tuning). Headline metrics: offline AUC 0.7674, online GMV +3.18%, orders per user +1.39%.]
Figure 1: PRECTR-V2 overview, showing three coordinated improvements to the unified relevance-CTR pipeline: a cold-start cross-user preference mining module with MoE incentive scoring, a pairwise debias loss with margin clipping and dynamic truncation weighting, and a Qwen-7B-distilled lightweight transformer that replaces the frozen BERT encoder. Deployed on Alibaba’s Xianyu platform after an A/B test on over 20% of live traffic.

In industrial search, relevance and click-through rate have traditionally been handled by separate systems — trained independently, combined with hand-tuned rules, and perpetually misaligned with each other. Alibaba’s prior work, PRECTR, tried to unify them. It worked. But it left three wounds untreated: cold-start users who barely left a behavioral trace, a training set so dominated by high-relevance exposed items that the model had never learned to score the broader candidate pool, and a frozen BERT encoder that could not be updated to close the gap between semantic representation and click prediction. PRECTR-V2 addresses all three at once — and the results, both offline and in a live A/B test on Xianyu’s search system, are substantial.

What makes this paper interesting beyond its benchmark numbers is the specificity of its diagnoses. The authors do not simply propose a larger model or a richer encoder. They trace each failure mode to a concrete architectural or data-distribution cause, then design a focused intervention for each one. The cold-start gap leads to a cross-user behavior mining strategy grounded in query category similarity. The exposure bias leads to a synthetic hard negative construction procedure with two regularization mechanisms to prevent over-correction. The frozen encoder leads to a distillation pipeline that squeezes Qwen-7B’s representational capacity into a three-layer transformer compact enough for production inference — and then jointly fine-tunes it end-to-end with the CTR task.

Together these three components add up to an AUC improvement of 0.0093 over PRECTR on Alibaba’s internal Xianyu dataset, a +1.39% lift in per-capita orders, and a +3.18% GMV increase. The model has since been fully deployed across Xianyu’s search system — a meaningful signal that the improvements survived the transition from offline ablation to real traffic.


Background: What PRECTR Got Right, and Where It Fell Short

Understanding PRECTR-V2 requires a clear picture of what it is building on. PRECTR proposed a key decomposition: the probability of a click can be modeled as the product of a conditional click probability (given a relevance score level) and the probability of that relevance level given the input, summed across all levels:

Eq. 1 — PRECTR Click Model $$P(\text{click}=1|\mathbf{x}) = \sum_{i=1}^{k} P(\text{click}=1|\text{rsl}=i,\mathbf{x})\cdot P(\text{rsl}=i|\mathbf{x})$$

This formulation elegantly couples relevance modeling and CTR prediction within a single architecture. Rather than combining them post-hoc with a fixed rule, PRECTR learns the coupling automatically. It also computes a personalized relevance incentive score \(\tau(\mathbf{x})\) derived from the user’s historical search behavior, which adjusts the base CTR estimate before final ranking.
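A minimal numeric illustration of Eq. 1, assuming the model exposes two heads: a per-level conditional click probability and logits over the four relevance levels. The function name and the numbers are hypothetical; only the arithmetic mirrors the formula.

```python
import numpy as np

def factored_ctr(p_click_given_rsl: np.ndarray, rsl_logits: np.ndarray) -> float:
    """Eq. 1: P(click=1|x) = sum_i P(click=1|rsl=i, x) * P(rsl=i|x)."""
    # Stable softmax turns relevance logits into P(rsl=i|x)
    z = rsl_logits - rsl_logits.max()
    p_rsl = np.exp(z) / np.exp(z).sum()
    # Weight each level's conditional click probability by how likely that level is
    return float(np.dot(p_click_given_rsl, p_rsl))

# Hypothetical model outputs for rsl = 1, 2, 3, 4
p_click_given_rsl = np.array([0.01, 0.03, 0.06, 0.10])
print(factored_ctr(p_click_given_rsl, np.zeros(4)))  # uniform levels: about 0.05
```

When the relevance head is confident about one level, the result collapses to that level's conditional click probability; when it is uncertain, the estimate is a blend.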

This approach was genuinely novel and produced strong results. But the architecture exposed three vulnerabilities that only became visible at Xianyu’s scale. First, the incentive score \(\tau\) depends on historical search interactions, yet a significant share of Xianyu’s user base is either new or low-activity. For these users, the behavioral sequence is sparse or nonexistent, and the incentive computation falls back to a near-uniform prior. The personalization signal, which drove much of PRECTR’s improvement, simply wasn’t there.

Second, PRECTR’s training data was drawn from exposure logs, and those logs had a strong selection effect: over 80% of exposed samples carried high relevance labels. The full candidate pool at inference time, however, contains many lower-relevance items that the model had never been trained to discriminate. A model optimized almost entirely on high-relevance data develops a blind spot precisely where the fine-grained ranking work gets done.

Third, and perhaps most structurally limiting: PRECTR used a frozen BERT model to generate text embeddings. Freezing BERT was necessary to meet production latency requirements — BERT-base’s 110 million parameters are simply too expensive to fine-tune online. But this architectural choice meant that the text representations feeding into the CTR model could never adapt to click behavior. No matter how much the downstream model learned, the upstream encoder remained static.

Key Takeaway

PRECTR-V2 targets three distinct failure modes of its predecessor: sparse behavioral data for cold-start users, distribution mismatch between training exposures and inference candidates, and the inability to jointly optimize text representation and CTR prediction under latency constraints.


The Full Architecture: Three Modules, One Joint Model

Figure 2: PRECTR-V2 full architecture. The Base+RSL module (left, gold) computes the factored click probability using a shared Lightweight Encoder. The Personalized Incentive Module (right, warm white) adds cross-user MHTA attention and a Mixture-of-Experts incentive score τ. A separate Debias Module (bottom, teal) applies pairwise ranking loss with margin clamping and dynamic truncation to correct training-inference distribution mismatch. All three modules are jointly trained.

Module 1: Cross-User Relevance Preference Mining

The core insight behind this module is a behavioral observation: users’ sensitivity to relevance is strongly correlated with the category of their query. Across Xianyu’s transaction logs, the average relevance level of items that users click varies markedly by category. A user searching for electronics expects a tighter relevance match than a user browsing vintage clothing. This category-level pattern is stable and observable in aggregate even when individual users leave almost no behavioral data.

PRECTR-V2 exploits this structure by replacing the sparse individual sequence of a cold-start user with a hybrid sequence: the user’s own (possibly empty) behavioral history plus a set of behaviors sampled from 50 globally active users in the same query category. The selection is further refined by query text similarity — among the sampled users’ historical interactions, only the top-k by cosine similarity to the current query are retained as supplementary context.
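A minimal sketch of the top-k step, assuming each sampled user's behaviors are stored as rows of historical query embeddings (keys) with matching query-item relevance embeddings (values). The function name, the array layout and k = 25 (borrowed from `max_cross_seq_len` in the configuration later in this article) are illustrative assumptions, not the paper's code. The returned keys and values would presumably play the role of \(K_q, V_q\) in the attention step, alongside the user's own history.

```python
import numpy as np

def select_cross_user_behaviors(cur_query: np.ndarray,
                                pool_keys: np.ndarray,
                                pool_vals: np.ndarray,
                                k: int = 25):
    """Keep the k pooled behaviors whose historical query embedding is most
    cosine-similar to the current query.

    cur_query: (d,) current query embedding
    pool_keys: (n, d) historical query embeddings from the sampled users
    pool_vals: (n, d_v) matching query-item relevance embeddings
    """
    def normalize(x):
        return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), 1e-12)

    sims = normalize(pool_keys) @ normalize(cur_query)   # (n,) cosine similarities
    top = np.argsort(-sims, kind="stable")[:k]           # most similar first
    return pool_keys[top], pool_vals[top], sims[top]

rng = np.random.default_rng(0)
cur = rng.normal(size=8)
pool_keys = rng.normal(size=(50 * 4, 8))   # e.g. 50 sampled users x 4 behaviors each
pool_keys[7] = 3.0 * cur                   # one behavior points the same way as the query
pool_vals = rng.normal(size=(200, 16))
keys, vals, sims = select_cross_user_behaviors(cur, pool_keys, pool_vals, k=25)
```

Cosine similarity ignores embedding magnitude, so the behavior at row 7 ranks first even though it is three times longer than the query vector.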

The merged sequence is then processed through a Multi-Head Target-Attention (MHTA) layer. The current query serves as the attention query; the historical query embeddings from the personal and cross-user sequences serve as keys; and the joint query-item relevance embeddings serve as values. This produces two representations: \(r^{user}\), capturing the individual user’s preferences, and \(r^{cate}\), capturing the query-category-level signal extracted from the global pool.

Eq. 2 — MHTA for Preference Extraction $$r^{user} = \text{Attn}(Q, K_u, V_u) = \text{softmax}\!\left(\frac{QK_u^\top}{\sqrt{d}}\right)V_u$$ $$r^{cate} = \text{Attn}(Q, K_q, V_q) = \text{softmax}\!\left(\frac{QK_q^\top}{\sqrt{d}}\right)V_q$$
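Eq. 2 is ordinary scaled dot-product attention with a single target query. The numpy sketch below is single-head and omits the learned Q/K/V projections that a multi-head layer (such as the `nn.MultiheadAttention` module used in the implementation at the end of this article) would apply; the optional padding mask is an added assumption for variable-length histories.

```python
import numpy as np

def target_attention(q, K, V, mask=None):
    """Single-head form of Eq. 2: softmax(q K^T / sqrt(d)) V for one target query.

    q: (d,) current query; K: (n, d) historical queries; V: (n, d_v) values;
    mask: optional (n,) bool, True = real behavior, False = padding.
    At least one position must be valid.
    """
    d = q.shape[-1]
    scores = K @ q / np.sqrt(d)
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)   # padded slots get zero weight
    scores = scores - scores.max()                # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum()
    return weights @ V, weights

rng = np.random.default_rng(0)
q = rng.normal(size=8)
K_u, V_u = rng.normal(size=(5, 8)), rng.normal(size=(5, 8))
r_user, attn_u = target_attention(q, K_u, V_u)   # the same call on K_q, V_q gives r^cate
```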

These two representations feed into a Mixture of Experts (MoE) network that computes the final incentive score \(\tau\). A gating network takes \(r^{user}\) as input and produces a two-dimensional soft label \([w_1, w_2]\) reflecting the user’s relevance sensitivity. Two expert networks — one with a softplus-style activation for users who favor high-relevance items, and one with its negated counterpart for users who are more tolerant of lower-relevance matches — process the combined context. The final score is their weighted combination:

Eq. 3 — MoE Incentive Score $$\tau = w_1 \cdot f_+(\mathbf{x}) + w_2 \cdot f_-(\mathbf{x}), \quad f_+(\mathbf{x}) = \log(1 + e^{E_+(\mathbf{x})}), \quad f_-(\mathbf{x}) = -\log(1 + e^{E_-(\mathbf{x})})$$

This design is elegant because it handles the full spectrum of user types without requiring explicit labels for relevance preference: the gate infers each user’s relevance sensitivity from \(r^{user}\), and the experts score the combined personal and category-level context. For a user who has never searched before, \(r^{user}\) carries almost no information, so the category-level signal in \(r^{cate}\) dominates what the experts see and the score falls back on how users in that query category behave, which is exactly the right behavior.
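A minimal numpy sketch of Eq. 3. It assumes single linear layers for the gate and both experts, and assumes the experts read the concatenation of \(r^{user}\) and \(r^{cate}\); the paper's layer sizes and exact expert inputs are not given in this article, so the weights below are random placeholders.

```python
import numpy as np

def softplus(z):
    # log(1 + e^z), computed stably
    return np.logaddexp(0.0, z)

def moe_incentive(r_user, r_cate, W_gate, w_pos, w_neg):
    """tau = w1 * f_plus(x) + w2 * f_minus(x)  (Eq. 3)."""
    logits = W_gate @ r_user                      # (2,) gate reads r^user only
    g = np.exp(logits - logits.max())
    w1, w2 = g / g.sum()                          # soft label [w1, w2]
    x = np.concatenate([r_user, r_cate])          # assumed expert input
    f_plus = softplus(w_pos @ x)                  # >= 0: favors high relevance
    f_minus = -softplus(w_neg @ x)                # <= 0: tolerates lower relevance
    return w1 * f_plus + w2 * f_minus, (w1, w2)

rng = np.random.default_rng(1)
d = 8
r_user, r_cate = rng.normal(size=d), rng.normal(size=d)
tau, (w1, w2) = moe_incentive(r_user, r_cate,
                              W_gate=rng.normal(size=(2, d)),
                              w_pos=rng.normal(size=2 * d),
                              w_neg=rng.normal(size=2 * d))
```

The sign constraint is what makes the two experts interpretable: one can only raise \(\tau\), the other can only lower it, and the gate decides how much of each a given user gets.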

Key Takeaway

Cross-user preference mining does not try to invent behavioral history for cold-start users. Instead, it borrows query-category-level preference patterns from globally active users, ranked by query text similarity, and combines them with the user’s own signal in a Mixture of Experts module whose gate balances a high-relevance expert against a relevance-tolerant one.


Module 2: Relevance Exposure Debias

The distribution mismatch problem here is a concrete consequence of how industrial ranking systems work. In Xianyu’s search pipeline, items shown to users must pass a relevance filter — meaning the items that generate training labels are already pre-selected to be relevant. Over 80% of exposed samples carry high relevance labels. But when the fine-grained ranking model is deployed at inference time, it scores a much broader candidate pool produced by coarser-stage filtering. The model has never seen what low-relevance items look like, so it has no calibrated opinion about them.

The obvious fix — sampling low-relevance items from the unexposed pool and adding them to training — has a serious side effect: it wrecks the positive-to-negative ratio, introduces low-quality items with poor item-side features, and makes it difficult to isolate the relevance signal from other confounders. PRECTR-V2 takes a more surgical approach: instead of using real unexposed items, it constructs synthetic hard negatives from existing positives.

For each clicked, highly relevant sample (click=1, rsl=4), the method generates a synthetic counterpart by randomly downgrading its relevance label to 1, 2, or 3 according to preset probabilities, and simultaneously perturbing its item text embedding with additive Gaussian noise:

Eq. 4 — Synthetic Hard Negative Construction $$Q^{\text{fake}}_{\text{emb}} = Q_{\text{emb}} + \epsilon, \quad \epsilon_i \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0,1)$$
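The construction in Eq. 4 can be sketched as follows. The sampling rule reads `p1 = 0.2` and `p2 = 0.6` from the configuration in the implementation below as cumulative thresholds (u < p1 gives rsl 1, p1 ≤ u < p2 gives rsl 2, otherwise rsl 3), which is one plausible reading of "preset probabilities". The function name and the `noise_std` argument are hypothetical; `noise_std=1.0` matches the N(0,1) noise in Eq. 4.

```python
import numpy as np

def make_hard_negatives(text_emb, clicks, rsl, p1=0.2, p2=0.6, noise_std=1.0, rng=None):
    """Clone (click=1, rsl=4) samples with a downgraded rsl and a noisy text embedding."""
    rng = np.random.default_rng() if rng is None else rng
    idx = np.flatnonzero((clicks == 1) & (rsl == 4))    # eligible positives
    u = rng.random(idx.size)
    # Assumed cumulative thresholds: P(rsl=1)=p1, P(rsl=2)=p2-p1, P(rsl=3)=1-p2
    fake_rsl = np.where(u < p1, 1, np.where(u < p2, 2, 3))
    noise = rng.normal(0.0, noise_std, size=(idx.size, text_emb.shape[1]))
    fake_emb = text_emb[idx] + noise                    # every other feature is reused as-is
    return idx, fake_rsl, fake_emb
```

Each returned index pairs a real positive `text_emb[idx[j]]` with its synthetic counterpart `fake_emb[j]`, which is the (x+, x−) pairing the pairwise loss needs.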

Because all other features are held constant, the only difference between the positive and its synthetic counterpart is the relevance dimension. This makes the signal clean: the model is explicitly taught that the same item, made slightly less relevant, should receive a lower score. A pairwise ranking loss enforces this ordering:

Eq. 5 — Basic Pairwise Ranking Loss $$\mathcal{L}_{\text{pair}} = \sum_{i=1}^{N} \log\!\left(1 + e^{-(f(x^+_i) - f(x^-_i))}\right)$$

The authors observed, however, that naive application of this loss causes score divergence: positive scores keep climbing, negative scores keep falling, and the gap between them grows unboundedly throughout training. This inflates the model’s raw score predictions and degrades PCOC — the ratio of predicted click probability to actual observed click rate — which is a critical calibration metric in production systems where predicted probabilities drive business decisions.

Two mechanisms address this. First, a margin clip: the pairwise loss is only active when the score gap between positive and negative is below a margin (set to 0.075, equal to the observed average deviation). Once the gap exceeds this margin, the gradient contribution goes to zero. Second, a dynamic truncation weight: if the mean predicted score of the positive samples in a batch reaches a critical threshold (set to 0.08, matching the platform’s online CTR), the loss weight drops to zero, preventing the model from pushing scores beyond a realistic range:

Eq. 6 — Regularized Debias Loss $$\mathcal{R}_{\text{debias}} = w(\mathbf{x}) \cdot \sum_{i=1}^{N}\log\!\left(1 + e^{\max(0,\,\text{margin}-(f(x^+_i)-f(x^-_i)))}\right)$$ $$w(\mathbf{x}) = \begin{cases} w & \text{if } \text{mean}(f(x^+)) < \text{threshold} \\ 0 & \text{if } \text{mean}(f(x^+)) \geq \text{threshold} \end{cases}$$
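A numpy sketch of Eq. 5 and Eq. 6 side by side, using the margin (0.075), truncation threshold (0.08) and weight (1.0) from the configuration in the implementation below. It assumes `f_pos` and `f_neg` are predicted click probabilities for matched positive and synthetic pairs, and it follows Eq. 6 literally: once a pair's gap reaches the margin, its term becomes the constant log 2, which contributes no gradient.

```python
import numpy as np

def pairwise_loss(f_pos, f_neg):
    """Eq. 5: sum_i log(1 + exp(-(f(x+_i) - f(x-_i))))."""
    return float(np.sum(np.logaddexp(0.0, -(f_pos - f_neg))))

def debias_loss(f_pos, f_neg, margin=0.075, threshold=0.08, weight=1.0):
    """Eq. 6: margin-clipped pairwise loss with a dynamic truncation weight."""
    # Dynamic truncation: switch the loss off once positives already score too high
    w = weight if np.mean(f_pos) < threshold else 0.0
    hinge = np.maximum(0.0, margin - (f_pos - f_neg))   # 0 once gap >= margin
    return w * float(np.sum(np.logaddexp(0.0, hinge)))

print(debias_loss(np.array([0.06]), np.array([0.05])))  # gap 0.01 < margin: loss above log 2
print(debias_loss(np.array([0.09]), np.array([0.00])))  # mean(f_pos) >= 0.08: truncated to 0.0
```

Unlike Eq. 5, whose loss keeps shrinking (and keeps pushing) as the gap grows, the clipped version stops rewarding separation beyond the margin.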

The result is that PRECTR-V2’s PCOC deviation (measured as deviation from the ideal value of 1.0) is 1.7%, compared to 2.3% for the baseline — a meaningful improvement in calibration precision while simultaneously correcting the exposure bias. The two regularization mechanisms are doing exactly what they were designed for: keeping ranking quality gains from coming at the expense of absolute prediction accuracy.
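PCOC itself is a simple ratio. The sketch below computes it as total predicted click probability over total observed clicks on an evaluation slice, plus the deviation from 1.0 that the 1.7% versus 2.3% comparison refers to; the toy arrays are invented for illustration.

```python
import numpy as np

def pcoc(pred_ctr, clicks):
    """Predicted CTR over observed CTR: sum(pred) / sum(clicks)."""
    return float(np.sum(pred_ctr) / np.sum(clicks))

def pcoc_deviation(pred_ctr, clicks):
    """Distance from perfect calibration (PCOC = 1.0)."""
    return abs(pcoc(pred_ctr, clicks) - 1.0)

pred = np.array([0.25, 0.125, 0.0625, 0.0625])   # invented predictions
clicks = np.array([0, 0, 1, 0])                  # one click among four impressions
print(pcoc(pred, clicks))                        # 0.5: the model under-predicts by half
```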

“It is unreasonable to infinitely magnify the relative order between x+ and x−. The margin penalty and dynamic truncation together ensure that debiasing improves ranking without inflating scores beyond the range that production CTR calibration requires.” — Cao et al., arXiv:2602.20676 (2026)

Module 3: LLM-Distilled CTR-Aligned Lightweight Encoder

Replacing a frozen BERT encoder with something jointly trainable sounds straightforward — but the engineering challenge is severe. BERT-base has 110 million parameters. Running it through backpropagation during CTR fine-tuning at the scale of 1.6 billion daily training samples is simply not feasible in production. The solution in PRECTR-V2 is to build something much smaller that has been pre-loaded with BERT-quality (or better) representational capability before it ever touches the CTR data.

The lightweight encoder consists of just three stacked Transformer layers, totaling approximately 2 million parameters — a 55× parameter reduction from BERT-base. The compression is achieved through a two-stage pretraining pipeline before any CTR-specific fine-tuning occurs.
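A back-of-the-envelope check on the ~2M figure. The arithmetic below counts the weights of a standard Transformer encoder layer: fused Q/K/V projection, attention output projection, two feed-forward linears and two LayerNorms, all with biases, which is the layout PyTorch's `nn.TransformerEncoderLayer` uses (the head count does not change the total). The dimensions d_model = 128 and ff_dim = 256 come from the configuration later in this article, not from the paper, so treat the result as illustrative.

```python
def encoder_layer_params(d_model: int, ff_dim: int) -> int:
    """Parameter count of one Transformer encoder layer with biases."""
    attn = 3 * d_model * d_model + 3 * d_model       # fused Q/K/V projection
    attn += d_model * d_model + d_model              # attention output projection
    ffn = (d_model * ff_dim + ff_dim) + (ff_dim * d_model + d_model)
    norms = 2 * (2 * d_model)                        # two LayerNorms (weight + bias)
    return attn + ffn + norms

per_layer = encoder_layer_params(d_model=128, ff_dim=256)
stack = 3 * per_layer
print(per_layer, stack)              # 132480 397440
print(110_000_000 / 2_000_000)       # 55.0, the reduction quoted in the text
```

Under these assumed dimensions, the three layers take about 0.4M parameters, leaving roughly 1.6M of a 2M budget (about 12.5k vocabulary rows at 128 dimensions) for embeddings and projections.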

Stage 1: Supervised Fine-Tuning on Relevance Classification

In the first stage, the encoder is trained as a text relevance classifier using query-item pairs with known relevance score levels (rsl ∈ {1,2,3,4}) sampled from the exposure space. A projection network \(M\) maps the encoder’s output to a four-class distribution, and the whole system is trained with standard cross-entropy loss against the relevance label. This gives the encoder a grounded understanding of what relevance means in Xianyu’s product domain specifically — not general language understanding, but the particular semantic structure that distinguishes “strongly relevant” from “weakly relevant” in second-hand goods search.

Stage 2: LLM Embedding Distillation

The second pretraining signal is knowledge distillation from Qwen-7B. For each query-item pair in the training set, Qwen-7B generates a text representation. The lightweight encoder is then trained to minimize the MSE between its output and Qwen-7B’s representation. This transfers the semantic richness of a 7-billion-parameter language model into the three-layer encoder without requiring Qwen-7B to be present at inference time.

The distillation is enhanced by Retrieval-Augmented Generation (RAG): before producing a query-item pair’s embedding, Qwen-7B is provided with highly similar samples and known relevance edge cases as template context. This improves the quality of Qwen-7B’s embeddings and therefore the quality of the distillation target — the teacher is helping the student more carefully, not just generating a generic embedding.

Eq. 7 — Combined Pretraining Objective $$\mathcal{R}_{\text{overall}} = \mathcal{R}_{\text{Distill}}(\theta) + \mathcal{R}_{\text{SFT}}(\theta)$$ $$\mathcal{R}_{\text{Distill}}(\theta) = \text{MSE}(g(\mathbf{x}),\, T(\mathbf{x};\theta))$$
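The `distillation_loss` function in the implementation below requires the teacher embedding to have the same width as the student output. If the pre-computed Qwen-7B embeddings are wider, one common workaround is a learned linear projection on the student side. The sketch below assumes that approach (it is not described in this article); `PretrainHeads` and the 512-dimensional teacher width are hypothetical placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PretrainHeads(nn.Module):
    """Hypothetical heads turning a student pair embedding into both Eq. 7 terms."""

    def __init__(self, student_dim: int, teacher_dim: int, num_rsl: int = 4):
        super().__init__()
        # Assumed bridge from the student width to the pre-computed teacher width
        self.to_teacher = nn.Linear(student_dim, teacher_dim)
        # Projection network M for 4-way relevance classification
        self.rsl_head = nn.Linear(student_dim, num_rsl)

    def forward(self, student_emb, teacher_emb, rsl):
        r_distill = F.mse_loss(self.to_teacher(student_emb), teacher_emb)
        # rsl levels are 1..4; cross_entropy expects class indices 0..3
        r_sft = F.cross_entropy(self.rsl_head(student_emb), rsl - 1)
        return r_distill + r_sft, r_distill, r_sft

heads = PretrainHeads(student_dim=128, teacher_dim=512)   # 512 is a placeholder width
student = torch.randn(6, 128, requires_grad=True)         # stands in for T(x; theta)
teacher = torch.randn(6, 512)                              # stands in for g(x)
rsl = torch.tensor([1, 2, 3, 4, 4, 2])
total, r_distill, r_sft = heads(student, teacher, rsl)
total.backward()   # gradients flow into the student embedding from both terms
```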

Following pretraining, the encoder is embedded into the full PRECTR-V2 model and the entire system is fine-tuned end-to-end on CTR data. The encoder’s learning rate is reduced to avoid catastrophic forgetting of the pretraining representations. This final joint optimization is what PRECTR could not do with frozen BERT — and it is what allows the semantic representations to actually adapt to the click signal.
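Lowering only the encoder's learning rate is typically done with optimizer parameter groups. A minimal sketch, assuming `lr = 1e-4` and `enc_lr_scale = 0.1` from the configuration below, with two stand-in linear layers (hypothetical names) in place of the real encoder and CTR tower:

```python
import torch
import torch.nn as nn

class ToyPRECTR(nn.Module):
    """Stand-in model: a pretrained text encoder plus the CTR layers above it."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 8)    # placeholder for the lightweight encoder
        self.ctr_tower = nn.Linear(8, 1)   # placeholder for the rest of the model

model = ToyPRECTR()
base_lr, enc_lr_scale = 1e-4, 0.1
optimizer = torch.optim.Adam([
    {"params": model.encoder.parameters(), "lr": base_lr * enc_lr_scale},
    {"params": model.ctr_tower.parameters()},          # inherits the default lr below
], lr=base_lr)

# One joint step: both groups are updated, the encoder ten times more gently
x = torch.randn(4, 16)
loss = model.ctr_tower(model.encoder(x)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```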

Key Takeaway

The lightweight encoder achieves competitive representational quality through a two-step process: domain-specific supervised fine-tuning on relevance classification, followed by Qwen-7B embedding distillation with RAG-enhanced teacher embeddings. The result is a 2M-parameter encoder that can be jointly optimized with the CTR model in production.


Experimental Results: Offline and Online

[Figure 3 chart: AUC and GAUC for LR, DNN, W&D, DIN, PRECTR and PRECTR-V2 on the Xianyu dataset, bar length proportional to RelaImpr over Wide&Deep; side panels "Online A/B Test (Xianyu, >20% live traffic)" and "Ablation: RI vs PRECTR". Baselines shown only in the chart: LR AUC 0.6795 (−29.1%), GAUC 0.6347 (−26.6%); DNN AUC 0.7541 (+0.35%), GAUC 0.6863 (+1.47%).]
Figure 3: Offline AUC and GAUC comparison across baseline models. PRECTR-V2 achieves 0.7674 AUC (5.61% RelaImpr over Wide&Deep) and 0.6933 GAUC (5.28% RelaImpr), outperforming all eight baselines. The ablation panel (bottom right) reports AUC RelaImpr over PRECTR for the ablated variants in Table 1 (Cold-Start +1.08%, Debias +2.98%, Encoder +3.13%) and for the full model (+3.60%). Every ablated variant trails the full model, and removing Cold-Start Preference Mining costs the most.

Offline Comparison

The offline evaluation used nine days of click logs from Xianyu: seven days for training and two for evaluation. At 1.6 billion daily training records, this is a genuinely large-scale industrial evaluation on the production data the model will actually face, not a standard academic benchmark. The comparison covered eight baseline models spanning logistic regression, deep MLP approaches, factorization machine variants, deep interest networks, and the prior PRECTR framework.

Method AUC RelaImpr (AUC) GAUC RelaImpr (GAUC)
Wide&Deep 0.7532 0.00% (base) 0.6836 0.00% (base)
DIN 0.7561 +1.15% 0.6875 +2.21%
SuKD 0.7524 −0.31% 0.6851 +0.81%
PRECTR 0.7581 +1.93% 0.6892 +3.05%
PRECTR-V2 (w/o Cold-Start) 0.7609 +1.08%* 0.6908 +0.85%*
PRECTR-V2 (w/o Debias) 0.7658 +2.98%* 0.6928 +1.91%*
PRECTR-V2 (w/o LLM Encoder) 0.7662 +3.13%* 0.6927 +1.85%*
PRECTR-V2 (Full) 0.7674 +5.61% 0.6933 +5.28%

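Table 1: Offline comparison results (selected rows). *Each ablation row removes one module, and its RelaImpr is measured relative to PRECTR rather than Wide&Deep; on that scale the full model reaches +3.60% AUC. All three components contribute positively; removing the Cold-Start Preference Mining module causes the largest drop, making it the single largest individual contribution.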
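RelaImpr is conventionally measured against a random ranker's AUC of 0.5: RelaImpr = ((AUC_model − 0.5) / (AUC_base − 0.5) − 1) × 100%. The short check below reproduces several entries of the table from its AUC and GAUC columns under that definition; the function name is illustrative.

```python
def rela_impr(auc_model: float, auc_base: float) -> float:
    """Relative improvement over a baseline, measured against random AUC = 0.5."""
    return ((auc_model - 0.5) / (auc_base - 0.5) - 1.0) * 100.0

print(round(rela_impr(0.7674, 0.7532), 2))  # 5.61  PRECTR-V2 vs Wide&Deep (AUC)
print(round(rela_impr(0.6933, 0.6836), 2))  # 5.28  PRECTR-V2 vs Wide&Deep (GAUC)
print(round(rela_impr(0.7609, 0.7581), 2))  # 1.08  w/o Cold-Start vs PRECTR
print(round(rela_impr(0.7674, 0.7581), 2))  # 3.6   full model vs PRECTR
```

Subtracting 0.5 first is what makes small AUC gaps look large: 0.0142 of raw AUC over Wide&Deep becomes a 5.61% relative improvement.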

Online A/B Testing

PRECTR-V2 was deployed to Xianyu’s live search system using an A/B testing framework that assigned users to control and experimental groups via MD5 hashing of device IDs. The experimental group received over 20% of total platform traffic — a substantial deployment for a model change. Over the test period, the experimental group showed a 1.39% improvement in per-capita orders and a 3.18% increase in GMV relative to the control. These are strong online gains for an incremental model update at this scale, and they resulted in the full deployment of PRECTR-V2 across Xianyu’s search system.
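Hash-based assignment of this kind is easy to sketch with Python's standard `hashlib`. The salt string, the 100-bucket granularity and the 20% cutoff below are illustrative assumptions; the text only states that device IDs were MD5-hashed and that the experimental group received over 20% of traffic.

```python
import hashlib

def ab_bucket(device_id: str, salt: str = "prectr_v2_exp", num_buckets: int = 100) -> int:
    """Deterministically map a device ID to a bucket in [0, num_buckets)."""
    digest = hashlib.md5(f"{salt}:{device_id}".encode("utf-8")).hexdigest()
    return int(digest, 16) % num_buckets

def assign_group(device_id: str, treatment_pct: int = 20) -> str:
    """Buckets below treatment_pct form the experimental group."""
    return "treatment" if ab_bucket(device_id) < treatment_pct else "control"
```

Because the mapping depends only on the device ID and the salt, a user lands in the same group on every request, and changing the salt reshuffles users for a new experiment.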

The search relevance impact deserves an honest note: PRECTR-V2 shows a 1.09% relative increase (0.15% absolute) in the rate of irrelevant items among top-10 results when measured through manual evaluation. The authors attribute this to the personalization aspects of the incentive scoring — the model is trading a small amount of strict relevance for better user-preference alignment. Whether this tradeoff is acceptable depends on the platform’s objectives, but the authors are transparent about it rather than simply reporting the positive metrics.


What This Work Tells Us About Industrial ML at Scale

PRECTR-V2 is interesting not just as a system but as a case study in how industrial ML teams think about improvement. Each of the three problems it addresses was identifiable in the original PRECTR deployment — but identifying them required both careful data analysis (the 80% high-relevance exposure observation, the category-relevance correlation) and production experience (the PCOC degradation risk from naive pairwise training, the latency constraints that ruled out full BERT fine-tuning).

The cross-user preference mining module is a good example of this. A purely algorithmic approach might have tried to handle cold-start users by augmenting the training objective with a meta-learning component or a contrastive regularizer. The PRECTR-V2 solution is more pragmatic: look at what behavioral signals exist at the population level, find an appropriate granularity (query category), and build a retrieval mechanism to bring that signal to individual users at inference time. This is a design philosophy that reflects engineering constraints as much as algorithmic elegance.

The exposure debias module is similarly calibrated. The margin clipping and dynamic truncation weighting were not in the original formulation — they were added in response to observed training dynamics that showed score divergence over time. The authors monitored the model’s predictions during training and adjusted the loss design based on what they saw. This kind of iterative tightening of an objective function based on empirical signal is standard practice in production ML but rarely documented in papers; here it is fully described and quantitatively validated.

The lightweight encoder story is perhaps the most transferable contribution. The challenge of replacing a large frozen pretrained model with something jointly trainable is nearly universal in industrial deep learning. PRECTR-V2’s approach — combine SFT on domain-specific labels with large-model distillation using RAG-enhanced teacher embeddings, then fine-tune end-to-end at reduced learning rate — is a template that applies well beyond recommendation systems. It is a principled way to compress representational quality into a deployable package without sacrificing the benefits of joint optimization.

The limitations are worth acknowledging. The system is deeply tied to Xianyu’s specific data distribution, the four-level relevance score framework, and the specific latency characteristics of Alibaba’s serving infrastructure. Adapting it to a different platform would require recalibrating the debias thresholds, retraining the encoder on domain-specific text pairs, and rebuilding the cross-user sampling logic around that platform’s behavioral patterns. This is not a criticism — it is simply the nature of production systems that are designed to work rather than to generalize. The components are modular enough that partial adaptation should be feasible.


Complete Model Implementation (PyTorch)

The following is a complete, runnable PyTorch implementation of PRECTR-V2, including the cross-user behavioral sequence pipeline with MHTA and MoE incentive scoring, the synthetic hard negative construction with embedding noise injection, the regularized pairwise debias loss with margin clipping and dynamic truncation weighting, the LLM-distilled lightweight Transformer encoder with SFT and MSE distillation pretraining objectives, the unified PRECTR-V2 forward pass with τ(x) incentive multiplication, and the full training loop with joint optimization. Settings follow the paper's reported setup where available: a 3-layer lightweight Transformer, MHTA with √d scaling, two-expert MoE gating, Adam optimizer, batch size 4096, p₁=0.2, p₂=0.6, threshold=0.08, margin=0.075. The encoder does not reproduce the paper's ~2M parameter count exactly: with the default vocab_size=100000 and enc_d_model=128, the token embedding table alone holds about 12.8M parameters.

# ─────────────────────────────────────────────────────────────────────────────
# PRECTR-V2: Unified Relevance–CTR Framework
#        with Cross-User Preference Mining, Exposure Bias Correction,
#        and LLM-Distilled Encoder Optimization
# Cao, Chen, He, Han, Chen · Alibaba Group / Xianyu · arXiv:2602.20676 (2026)
# Full implementation: MoE incentive, debias loss, lightweight encoder
# ─────────────────────────────────────────────────────────────────────────────

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional
from collections import defaultdict


# ─── Section 1: Configuration ─────────────────────────────────────────────────

@dataclass
class PRECTRv2Config:
    """
    Hyperparameters matching the paper's experimental setup.
    Optimal values from Xianyu production grid search (Section 5.3).
    """
    # Embedding
    embed_dim: int = 64             # sparse feature embedding dimension
    text_embed_dim: int = 128       # lightweight encoder output dim
    num_fields: int = 20           # number of feature fields
    vocab_size: int = 100000       # feature vocabulary size

    # Lightweight Encoder (Section 4.3)
    enc_num_layers: int = 3         # 3 stacked Transformer layers
    enc_d_model: int = 128         # encoder hidden dim
    enc_nhead: int = 4             # attention heads
    enc_ff_dim: int = 256          # feedforward dim inside encoder
    enc_max_seq_len: int = 64      # max token length for query/item text

    # Relevance score levels
    num_rsl: int = 4               # rsl ∈ {1,2,3,4} = irrelevant..strong

    # Behavior sequences
    max_user_seq_len: int = 25     # max personal behavior sequence length
    max_cross_seq_len: int = 25    # top-k cross-user behaviors (Section 4.1)
    num_sampled_users: int = 50    # cross-user sample pool size (Section 4.1)
    mhta_heads: int = 4           # Multi-Head Target-Attention heads

    # MLP dimensions
    mlp_hidden: int = 256
    mlp_layers: int = 3

    # Debias hyperparameters (Section 5.1 / 5.3)
    p1: float = 0.2               # fake_rsl=1 probability threshold
    p2: float = 0.6               # fake_rsl=2 probability threshold
    debias_margin: float = 0.075   # margin for pairwise loss clamp
    debias_threshold: float = 0.08 # dynamic truncation threshold (≈ online CTR)
    debias_weight: float = 1.0    # w when mean(f(x+)) < threshold

    # Training
    batch_size: int = 4096
    lr: float = 1e-4
    enc_lr_scale: float = 0.1     # reduced LR for encoder during end-to-end tuning
    num_epochs: int = 5
    random_seed: int = 42
    history_days: int = 30        # behavioral history window (Section 5.1)


# ─── Section 2: Lightweight Transformer Encoder ───────────────────────────────

class LightweightEncoder(nn.Module):
    """
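    3-layer Transformer encoder that replaces frozen BERT. The paper reports ~2M params;
    with this config the 3 layers hold ~0.4M and the token table (vocab_size x 128) ~12.8M.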
    Used for both query and item text encoding.

    Pretrained via two-stage pipeline (Section 4.3):
      Stage 1: SFT on relevance classification (cross-entropy, Formula 17)
      Stage 2: Knowledge distillation from Qwen-7B (MSE, Formula 18)
    Then jointly fine-tuned end-to-end with CTR at reduced learning rate.
    """

    def __init__(self, vocab_size: int, cfg: PRECTRv2Config):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size + 1, cfg.enc_d_model, padding_idx=0)
        self.pos_emb   = nn.Embedding(2 * cfg.enc_max_seq_len, cfg.enc_d_model)  # query + item pairs can reach 2x max_seq_len

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=cfg.enc_d_model,
            nhead=cfg.enc_nhead,
            dim_feedforward=cfg.enc_ff_dim,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=cfg.enc_num_layers)
        self.out_proj = nn.Linear(cfg.enc_d_model, cfg.text_embed_dim)

        # SFT projection head for relevance classification pretraining
        self.rsl_head = nn.Linear(cfg.text_embed_dim, cfg.num_rsl)

    def encode(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Encode token sequence to text embedding (CLS pooling).
        token_ids: (B, seq_len)
        Returns: (B, text_embed_dim)
        """
        B, L = token_ids.shape
        pos = torch.arange(L, device=token_ids.device).unsqueeze(0).expand(B, -1)
        x = self.token_emb(token_ids) + self.pos_emb(pos)
        key_pad_mask = (token_ids == 0)  # padding mask
        x = self.transformer(x, src_key_padding_mask=key_pad_mask)
        cls_rep = x[:, 0, :]             # CLS token representation
        return self.out_proj(cls_rep)      # (B, text_embed_dim)

    def encode_pair(self, query_ids: torch.Tensor, item_ids: torch.Tensor) -> torch.Tensor:
        """Encode [CLS] Q [SEP] I [SEP] pair for relevance representation (Formula 5)."""
        # Concatenate query and item token sequences
        pair_ids = torch.cat([query_ids, item_ids], dim=-1)
        return self.encode(pair_ids)

    def classify_rsl(self, pair_emb: torch.Tensor) -> torch.Tensor:
        """Predict relevance score level distribution for SFT training."""
        return F.softmax(self.rsl_head(pair_emb), dim=-1)   # (B, 4)


# ─── Section 3: SFT and Distillation Pretraining ─────────────────────────────

def sft_loss(encoder: LightweightEncoder,
             query_ids: torch.Tensor,
             item_ids: torch.Tensor,
             rsl_labels: torch.Tensor) -> torch.Tensor:
    """
    Textual Relevance Classification SFT — Formula (17).
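    R_SFT(θ) = −(1/n) Σ onehot(rsl) · log(softmax(M(T(x;θ))))
    rsl_labels hold levels in {1,2,3,4}; they are shifted to class indices {0..3}.
    """
    pair_emb = encoder.encode_pair(query_ids, item_ids)
    logits = encoder.rsl_head(pair_emb)                 # (B, 4) unnormalized scores
    return F.cross_entropy(logits, rsl_labels.long() - 1)  # applies log-softmax internally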


def distillation_loss(encoder: LightweightEncoder,
                       query_ids: torch.Tensor,
                       item_ids: torch.Tensor,
                       teacher_emb: torch.Tensor) -> torch.Tensor:
    """
    LLM Embedding Distillation — Formula (18).
    R_Distill(θ) = MSE(g(x), T(x;θ))
    teacher_emb: Qwen-7B representations (pre-computed with RAG); must have shape
                 (B, text_embed_dim) for F.mse_loss, so project them beforehand if wider
    """
    student_emb = encoder.encode_pair(query_ids, item_ids)
    return F.mse_loss(student_emb, teacher_emb)


def pretrain_encoder(encoder: LightweightEncoder,
                      pretrain_data,
                      cfg: PRECTRv2Config,
                      num_steps: int = 10000):
    """
    Two-stage pretraining pipeline:
      1. SFT: Relevance classification on query-item pairs
      2. Distillation: MSE against Qwen-7B RAG-enhanced embeddings
    R_overall = R_Distill(θ) + R_SFT(θ)   — Formula (19)
    """
    optimizer = optim.Adam(encoder.parameters(), lr=1e-4)

    for step in range(num_steps):
        batch = random.choice(pretrain_data)
        query_ids    = batch['query_ids']
        item_ids     = batch['item_ids']
        rsl_labels   = batch['rsl_labels']
        teacher_emb  = batch['teacher_emb']  # Qwen-7B + RAG embeddings

        r_sft      = sft_loss(encoder, query_ids, item_ids, rsl_labels)
        r_distill  = distillation_loss(encoder, query_ids, item_ids, teacher_emb)
        loss = r_distill + r_sft                  # Formula (19)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 1000 == 0:
            print(f"Pretrain step {step:5d} | SFT: {r_sft.item():.4f} | Distill: {r_distill.item():.4f}")


# ─── Section 4: Multi-Head Target-Attention (MHTA) ───────────────────────────

class MultiHeadTargetAttention(nn.Module):
    """
    Target-Attention layer for relevance preference extraction (Formulas 6-7).
    The current query acts as Q; historical query embeddings as K;
    historical relevance embeddings as V.
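    Each head scales its logits by 1/√d_k (d_k = d_model / num_heads, so d_model
    must be divisible by num_heads); this keeps the softmax from saturating and
    its gradients from vanishing.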
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.d_model = d_model

    def forward(self,
                query_cur: torch.Tensor,    # (B, d) current query emb
                key_seq: torch.Tensor,      # (B, seq_len, d) historical queries
                value_seq: torch.Tensor,    # (B, seq_len, d) historical rel embs
                key_pad_mask: Optional[torch.Tensor] = None
                ) -> torch.Tensor:
        """
        Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, computed per head
        Returns: (B, d) preference representation
        """
        Q = query_cur.unsqueeze(1)      # (B, 1, d)
        out, _ = self.attn(Q, key_seq, value_seq, key_padding_mask=key_pad_mask)
        return out.squeeze(1)            # (B, d)
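

# Sketch (illustrative, not the paper's code): the target-attention formula
# written out for a single head, assuming no learned projections. Unlike
# nn.MultiheadAttention, this hypothetical helper skips the Q/K/V/output
# projections and scales by 1/sqrt(d_k) with d_k equal to the full embedding
# size. Padded history slots (key_pad_mask == True) receive zero weight; each
# row must keep at least one unmasked slot, otherwise the softmax yields NaN.
# (torch/F are imported at the top of the listing; repeated so this runs alone.)
import math
from typing import Optional

import torch
import torch.nn.functional as F


def target_attention_single_head(q: torch.Tensor,
                                 keys: torch.Tensor,
                                 values: torch.Tensor,
                                 key_pad_mask: Optional[torch.Tensor] = None
                                 ) -> torch.Tensor:
    """q: (B, d); keys, values: (B, n, d); key_pad_mask: (B, n) bool. Returns (B, d)."""
    d_k = q.shape[-1]
    logits = torch.einsum('bd,bnd->bn', q, keys) / math.sqrt(d_k)   # (B, n)
    if key_pad_mask is not None:
        logits = logits.masked_fill(key_pad_mask, float('-inf'))    # ignore padding
    weights = F.softmax(logits, dim=-1)                             # rows sum to 1
    return torch.einsum('bn,bnd->bd', weights, values)              # weighted values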


# ─── Section 5: Mixture-of-Experts Incentive Module ─────────────────────────

class MoEIncentiveModule(nn.Module):
    """
    Mixture of Experts incentive score τ (Formulas 8-10).
    Gate: w = [w1, w2] from r^user via gating network
    Expert+: f+(x) = log(1 + exp(E+(x)))  — for high-relevance users
    Expert−: f−(x) = −log(1 + exp(E−(x))) — for low-relevance users
    τ = w1·f+(x) + w2·f−(x)
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        # Gating network: r^user → 2-dim soft distribution
        self.gate = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
            nn.Softmax(dim=-1)
        )
        # Expert networks (both take concatenation of r^user, r^cate, r^target)
        expert_input = 3 * input_dim
        self.expert_pos = nn.Sequential(  # E+(x)
            nn.Linear(expert_input, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.expert_neg = nn.Sequential(  # E−(x)
            nn.Linear(expert_input, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self,
                r_user: torch.Tensor,    # (B, d) from personal MHTA
                r_cate: torch.Tensor,    # (B, d) from cross-user MHTA
                r_target: torch.Tensor   # (B, d) target query-item emb
                ) -> torch.Tensor:
        """Returns incentive score τ of shape (B,)."""
        w = self.gate(r_user)                              # (B, 2)
        x_expert = torch.cat([r_user, r_cate, r_target], dim=-1)   # (B, 3d)

        e_pos = self.expert_pos(x_expert).squeeze(-1)   # (B,) = E+(x)
        e_neg = self.expert_neg(x_expert).squeeze(-1)   # (B,) = E−(x)

        # softplus(x) = log(1 + exp(x)), computed without overflow for large x
        f_pos = F.softplus(e_pos)                         # Formula (8)
        f_neg = -F.softplus(e_neg)                        # Formula (9)

        tau = w[:, 0] * f_pos + w[:, 1] * f_neg         # Formula (10)
        return tau                                        # (B,)
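

# Sketch: Formulas (8)-(9) are the softplus function, so f+ >= 0 and f- <= 0,
# and τ is a gate-weighted mix of a non-negative boost and a non-positive
# penalty. Evaluating log(1 + exp(x)) literally overflows to inf in float32
# once x exceeds roughly 88; F.softplus computes the same quantity stably.
# incentive_from_experts is a hypothetical helper for illustration only.
import torch
import torch.nn.functional as F


def incentive_from_experts(w: torch.Tensor,
                           e_pos: torch.Tensor,
                           e_neg: torch.Tensor) -> torch.Tensor:
    """w: (B, 2) gate weights; e_pos, e_neg: (B,) raw expert outputs. Returns τ: (B,)."""
    f_pos = F.softplus(e_pos)        # log(1 + exp(E+)), always >= 0
    f_neg = -F.softplus(e_neg)       # -log(1 + exp(E-)), always <= 0
    return w[:, 0] * f_pos + w[:, 1] * f_neg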


# ─── Section 6: Cold-Start Cross-User Preference Mining ──────────────────────

class CrossUserPreferenceMiner(nn.Module):
    """
    Cold-Start Personalized Relevance Preferences Mining (Section 4.1).
    Merges user's personal sequence S_u with cross-user sampled sequence S_q.
    Applies MHTA to produce r^user and r^cate.
    """

    def __init__(self, encoder: LightweightEncoder, cfg: PRECTRv2Config):
        super().__init__()
        self.encoder = encoder
        self.cfg = cfg
        d = cfg.text_embed_dim

        self.mhta_user = MultiHeadTargetAttention(d, cfg.mhta_heads)
        self.mhta_cate = MultiHeadTargetAttention(d, cfg.mhta_heads)

    def forward(self,
                q_cur_ids: torch.Tensor,            # (B, seq_len) current query tokens
                user_q_ids: torch.Tensor,           # (B, m, seq_len) personal queries
                user_i_ids: torch.Tensor,           # (B, m, seq_len) personal items
                cross_q_ids: torch.Tensor,          # (B, k, seq_len) cross-user queries
                cross_i_ids: torch.Tensor,          # (B, k, seq_len) cross-user items
                target_q_ids: torch.Tensor,         # (B, seq_len) target query
                target_i_ids: torch.Tensor          # (B, seq_len) target item
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns (r^user, r^cate, r^target) preference representations.
        """
        B, m, L = user_q_ids.shape
        k = cross_q_ids.shape[1]

        # Encode current query
        q_cur_emb = self.encoder.encode(q_cur_ids)         # (B, d)

        # Encode personal sequence: queries as keys, relevance pairs as values
        uq_flat = user_q_ids.view(B * m, L)
        ui_flat = user_i_ids.view(B * m, L)
        user_key_embs = self.encoder.encode(uq_flat).view(B, m, -1)    # (B,m,d)
        user_val_embs = self.encoder.encode_pair(uq_flat, ui_flat).view(B, m, -1)

        # Encode cross-user sequence
        cq_flat = cross_q_ids.view(B * k, L)
        ci_flat = cross_i_ids.view(B * k, L)
        cate_key_embs = self.encoder.encode(cq_flat).view(B, k, -1)     # (B,k,d)
        cate_val_embs = self.encoder.encode_pair(cq_flat, ci_flat).view(B, k, -1)

        # MHTA attention (Formulas 6, 7)
        r_user = self.mhta_user(q_cur_emb, user_key_embs, user_val_embs)   # (B, d)
        r_cate = self.mhta_cate(q_cur_emb, cate_key_embs, cate_val_embs)   # (B, d)

        # Target query-item relevance embedding
        r_target = self.encoder.encode_pair(target_q_ids, target_i_ids)    # (B, d)

        return r_user, r_cate, r_target
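

# Sketch: the miner encodes a (B, m, L) history by folding the history axis
# into the batch axis, encoding all B*m sequences in one call, and unfolding
# the result. encode_history is a hypothetical helper; it uses reshape, which
# also accepts non-contiguous inputs (e.g. sliced or transposed tensors) for
# which .view can raise.
import torch


def encode_history(encode_fn, token_ids: torch.Tensor) -> torch.Tensor:
    """token_ids: (B, m, L) -> (B, m, d), where encode_fn maps (N, L) -> (N, d)."""
    B, m, L = token_ids.shape
    flat = token_ids.reshape(B * m, L)              # (B*m, L)
    return encode_fn(flat).reshape(B, m, -1)        # (B, m, d)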


# ─── Section 7: Base Module and RSL Module ───────────────────────────────────

class BaseModule(nn.Module):
    """
    Base CTR module: sparse feature embeddings → MLP → P(click=1|rsl=i,x).
    Models conditional click probability for each relevance score level.
    """

    def __init__(self, cfg: PRECTRv2Config):
        super().__init__()
        self.embed = nn.Embedding(cfg.vocab_size + 1, cfg.embed_dim, padding_idx=0)
        input_dim = cfg.num_fields * cfg.embed_dim

        layers = []
        prev_dim = input_dim
        for _ in range(cfg.mlp_layers):
            layers += [nn.Linear(prev_dim, cfg.mlp_hidden), nn.ReLU(), nn.Dropout(0.1)]
            prev_dim = cfg.mlp_hidden
        layers.append(nn.Linear(prev_dim, cfg.num_rsl))  # logits per rsl level
        self.mlp = nn.Sequential(*layers)

    def forward(self, feature_ids: torch.Tensor) -> torch.Tensor:
        """
        feature_ids: (B, num_fields)
        Returns: (B, num_rsl) — P(click=1|rsl=i, x) for each level i
        """
        emb = self.embed(feature_ids).view(feature_ids.shape[0], -1)  # (B, fields*d)
        return torch.sigmoid(self.mlp(emb))   # (B, 4) — one per RSL level


class RSLModule(nn.Module):
    """
    RSL Module: text features → P(rsl=i|x).
    Models probability distribution over relevance score levels.
    """

    def __init__(self, cfg: PRECTRv2Config):
        super().__init__()
        input_dim = cfg.text_embed_dim  # from lightweight encoder
        layers = []
        prev_dim = input_dim
        for _ in range(cfg.mlp_layers):
            layers += [nn.Linear(prev_dim, cfg.mlp_hidden), nn.ReLU(), nn.Dropout(0.1)]
            prev_dim = cfg.mlp_hidden
        layers.append(nn.Linear(prev_dim, cfg.num_rsl))
        self.mlp = nn.Sequential(*layers)

    def forward(self, text_emb: torch.Tensor) -> torch.Tensor:
        """
        text_emb: (B, text_embed_dim) from encoder
        Returns: (B, num_rsl) — P(rsl=i|x) distribution
        """
        return F.softmax(self.mlp(text_emb), dim=-1)   # (B, 4)
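

# Worked example of Formula (1) with made-up numbers (hypothetical, for
# illustration only): P(click|x) is the expectation of the per-level click
# probabilities under P(rsl=i|x), so it always lies between the smallest and
# the largest P(click|rsl=i,x).
def factored_ctr(p_click_given_rsl, p_rsl):
    """Both arguments are length-num_rsl sequences; p_rsl must sum to 1."""
    if abs(sum(p_rsl) - 1.0) > 1e-6:
        raise ValueError("p_rsl must be a probability distribution")
    return sum(c * r for c, r in zip(p_click_given_rsl, p_rsl))


# 0.01*0.1 + 0.02*0.2 + 0.05*0.3 + 0.10*0.4 = 0.001 + 0.004 + 0.015 + 0.040 = 0.060
example_ctr = factored_ctr([0.01, 0.02, 0.05, 0.10], [0.1, 0.2, 0.3, 0.4])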


# ─── Section 8: Exposure Debias Loss ─────────────────────────────────────────

def construct_fake_negatives(
        emb_pos: torch.Tensor,        # (B, d) text embeddings of positive items
        rsl_pos: torch.Tensor,        # (B,) rsl labels of positive samples
        cfg: PRECTRv2Config
        ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Construct synthetic hard negatives for exposure debias (Section 4.2).
    For each clicked, rsl=4 positive sample:
      1. Reset rsl label to fake_rsl ∈ {1,2,3}: draw u ~ U(0,1); fake_rsl = 1 if u < p1,
         2 if p1 ≤ u < p2, else 3 (assumes p1 < p2)
      2. Add Gaussian noise to text embedding (Formula 12)
    Returns: (fake_emb, fake_rsl) — same shape as input
    """
    fake_emb = emb_pos.clone()
    fake_rsl = rsl_pos.clone()

    mask_rsl4 = (rsl_pos == 4)   # only highly relevant + clicked samples

    if mask_rsl4.any():
        # Noise injection (Formula 12-13)
        noise = torch.randn_like(emb_pos[mask_rsl4])
        fake_emb[mask_rsl4] = emb_pos[mask_rsl4] + noise

        # Fake RSL construction (Formula 11)
        n = mask_rsl4.sum().item()
        rand_vals = torch.rand(n, device=rsl_pos.device)
        new_rsl = torch.ones(n, dtype=rsl_pos.dtype, device=rsl_pos.device)
        new_rsl[rand_vals >= cfg.p1] = 2
        new_rsl[rand_vals >= cfg.p2] = 3
        fake_rsl[mask_rsl4] = new_rsl

    return fake_emb, fake_rsl
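

# Sketch: with u ~ Uniform(0, 1), the two-threshold rule above yields
#   P(fake_rsl=1) = p1, P(fake_rsl=2) = p2 - p1, P(fake_rsl=3) = 1 - p2,
# which is a valid split only when 0 <= p1 <= p2 <= 1. Both helpers are
# hypothetical: one returns the analytic split, the other reproduces the
# sampling rule with an explicit generator so the check is repeatable.
from typing import Optional

import torch


def fake_rsl_distribution(p1: float, p2: float) -> dict:
    if not 0.0 <= p1 <= p2 <= 1.0:
        raise ValueError("need 0 <= p1 <= p2 <= 1")
    return {1: p1, 2: p2 - p1, 3: 1.0 - p2}


def sample_fake_rsl(n: int, p1: float, p2: float,
                    generator: Optional[torch.Generator] = None) -> torch.Tensor:
    u = torch.rand(n, generator=generator)
    rsl = torch.ones(n, dtype=torch.long)
    rsl[u >= p1] = 2          # same thresholds as construct_fake_negatives
    rsl[u >= p2] = 3
    return rsl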


def compute_debias_loss(
        score_pos: torch.Tensor,      # (B,) model scores for positive samples
        score_neg: torch.Tensor,      # (B,) model scores for synthetic negatives
        cfg: PRECTRv2Config
        ) -> torch.Tensor:
    """
    Regularized Debias Loss — Formula (15-16).
    R_debias = w(x) * Σ log(1 + exp(max(0, margin − (f(x+) − f(x−)))))

    Two regularization mechanisms:
      1. Margin clip: only penalizes when score gap < margin (0.075)
      2. Dynamic truncation: w(x)=0 when mean(f(x+)) >= threshold (0.08)
    These preserve PCOC calibration while correcting ranking order.
    """
    diff = score_pos - score_neg                          # positive should be higher
    clipped = torch.clamp(cfg.debias_margin - diff, min=0.0)
    pair_loss = torch.log(1.0 + torch.exp(clipped))     # Formula (15) inner term

    # Dynamic truncation: w(x) = 0 if batch mean of positives >= threshold
    mean_pos = score_pos.mean().item()
    w = cfg.debias_weight if mean_pos < cfg.debias_threshold else 0.0

    return w * pair_loss.mean()   # Formula (15) full
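

# Worked example (pure Python, hypothetical helpers) of the two regularizers,
# using the margin 0.075 and threshold 0.08 quoted in the docstring above; the
# default weight=1.0 stands in for cfg.debias_weight and is an assumption.
# Since log(1 + exp(0)) = log 2, a pair whose gap already exceeds the margin
# still contributes a constant log 2 but zero gradient: the clip stops pushing
# the scores apart rather than driving the loss to zero.
import math


def debias_pair_term(s_pos: float, s_neg: float, margin: float = 0.075) -> float:
    return math.log1p(math.exp(max(0.0, margin - (s_pos - s_neg))))


def truncation_weight(pos_scores, threshold: float = 0.08, weight: float = 1.0) -> float:
    mean_pos = sum(pos_scores) / len(pos_scores)
    return weight if mean_pos < threshold else 0.0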


# ─── Section 9: Full PRECTR-V2 Model ─────────────────────────────────────────

class PRECTRv2(nn.Module):
    """
    PRECTR-V2: Unified Relevance–CTR Framework.

    Forward pass integrates:
      1. Base + RSL module: factored click probability (Formula 1)
      2. Cross-user preference mining: r^user, r^cate via MHTA (Formula 6-7)
      3. MoE incentive scoring: τ(x) via Expert+/Expert− (Formula 8-10)
      4. Final ranking score: rank_score = τ(x) × P(click|x) (Formula 2)
      5. Debias loss: R_debias for joint training (Formula 15-16)
    """

    def __init__(self, cfg: PRECTRv2Config):
        super().__init__()
        self.cfg = cfg

        # Lightweight encoder (shared across all text encoding)
        self.encoder = LightweightEncoder(cfg.vocab_size, cfg)

        # Base + RSL modules
        self.base_module = BaseModule(cfg)
        self.rsl_module  = RSLModule(cfg)

        # Cross-user preference miner
        self.preference_miner = CrossUserPreferenceMiner(self.encoder, cfg)

        # MoE incentive module
        self.moe_incentive = MoEIncentiveModule(
            input_dim=cfg.text_embed_dim, hidden_dim=128)

    def forward(self, batch: dict) -> dict:
        """
        Full PRECTR-V2 forward pass.
        batch keys:
          feature_ids:   (B, num_fields) — sparse feature indices
          q_cur_ids:     (B, seq_len)    — current query token ids
          i_cur_ids:     (B, seq_len)    — target item token ids
          user_q_ids:    (B, m, seq_len) — personal history queries
          user_i_ids:    (B, m, seq_len) — personal history items
          cross_q_ids:   (B, k, seq_len) — cross-user sampled queries
          cross_i_ids:   (B, k, seq_len) — cross-user sampled items

        Returns dict with:
          p_click:       (B,) — base CTR prediction P(click|x)
          rank_score:    (B,) — final ranking score τ(x) × P(click|x)
          text_emb:      (B, text_embed_dim) — for RSL module and debias
        """
        B = batch['feature_ids'].shape[0]

        # ── RSL Module: P(rsl=i|x) ────────────────────────────────────────
        text_emb = self.encoder.encode_pair(
            batch['q_cur_ids'], batch['i_cur_ids'])               # (B, d_text)
        p_rsl = self.rsl_module(text_emb)                          # (B, 4)

        # ── Base Module: P(click=1|rsl=i, x) ──────────────────────────────
        p_click_given_rsl = self.base_module(batch['feature_ids']) # (B, 4)

        # ── Factored CTR: P(click|x) = Σᵢ P(click|rsl=i,x)·P(rsl=i|x) ──
        p_click = (p_click_given_rsl * p_rsl).sum(dim=-1)            # (B,) — Formula (1)

        # ── Cross-User Preference Mining → MoE Incentive ──────────────────
        r_user, r_cate, r_target = self.preference_miner(
            batch['q_cur_ids'],
            batch['user_q_ids'], batch['user_i_ids'],
            batch['cross_q_ids'], batch['cross_i_ids'],
            batch['q_cur_ids'],  batch['i_cur_ids'])

        tau = self.moe_incentive(r_user, r_cate, r_target)            # (B,) — Formula (10)

        # ── Final Ranking Score ────────────────────────────────────────────
        rank_score = tau * p_click                                    # Formula (2)

        return {'p_click': p_click, 'rank_score': rank_score, 'text_emb': text_emb}


# ─── Section 10: Training ────────────────────────────────────────────────────

class PRECTRv2Trainer:
    """
    Joint trainer for PRECTR-V2.
    Optimizes three losses simultaneously:
      1. CTR binary cross-entropy on click labels
      2. RSL cross-entropy on relevance score level labels
      3. Debias pairwise ranking loss (R_debias, Formulas 15-16)
    The encoder is trained at a reduced learning rate (enc_lr_scale=0.1)
    to prevent catastrophic forgetting of pretrained representations.
    """

    def __init__(self, model: PRECTRv2, cfg: PRECTRv2Config):
        self.model = model
        self.cfg   = cfg

        # Separate parameter groups: lower LR for encoder
        encoder_params = list(model.encoder.parameters())
        other_params   = [p for p in model.parameters()
                          if not any(p is ep for ep in encoder_params)]

        self.optimizer = optim.Adam([
            {'params': other_params,   'lr': cfg.lr},
            {'params': encoder_params, 'lr': cfg.lr * cfg.enc_lr_scale}
        ])

    def train_step(self, batch: dict) -> dict:
        """
        One training step with joint loss optimization.
        batch must include all forward pass keys plus:
          click_labels: (B,) — binary {0,1}
          rsl_labels:   (B,) — integer {1,2,3,4}
        """
        self.model.train()
        out = self.model(batch)

        p_click    = out['p_click']
        text_emb   = out['text_emb']
        click_labs = batch['click_labels'].float()
        rsl_labs   = batch['rsl_labels']

        # CTR loss (binary cross-entropy)
        loss_ctr = F.binary_cross_entropy(p_click, click_labs)

        # RSL classification loss
        rsl_logits = self.model.rsl_module.mlp(text_emb)
        loss_rsl   = F.cross_entropy(rsl_logits, (rsl_labs - 1).long())  # 1-indexed→0-indexed

        # Debias loss (only for clicked, rsl=4 positives)
        mask_pos = (click_labs == 1) & (rsl_labs == 4)
        loss_debias = torch.zeros((), device=p_click.device)

        if mask_pos.sum() > 1:
            emb_pos = text_emb[mask_pos]
            # Construct synthetic negatives via noise injection + fake RSL
            fake_emb, _ = construct_fake_negatives(
                emb_pos, rsl_labs[mask_pos], self.cfg)

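            # Score each positive and its fake negative with the same factored CTR
            # (Formula 1). Both share the sparse features, so only the text side differs.
            # (Summing the RSL softmax alone would give a constant 1 with no gradient.)
            p_click_given_rsl = self.model.base_module(
                batch['feature_ids'][mask_pos])                          # (n_pos, 4)
            score_pos_raw = (p_click_given_rsl * self.model.rsl_module(emb_pos)).sum(dim=-1)
            score_neg_raw = (p_click_given_rsl * self.model.rsl_module(fake_emb)).sum(dim=-1)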

            loss_debias = compute_debias_loss(score_pos_raw, score_neg_raw, self.cfg)

        # Joint loss
        total_loss = loss_ctr + loss_rsl + loss_debias

        self.optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
        self.optimizer.step()

        return {
            'loss_total':  total_loss.item(),
            'loss_ctr':    loss_ctr.item(),
            'loss_rsl':    loss_rsl.item(),
            'loss_debias': loss_debias.item() if isinstance(loss_debias, torch.Tensor)
                            else loss_debias,
        }

    def train(self, dataloader, log_every: int = 100):
        """Full training loop."""
        for epoch in range(self.cfg.num_epochs):
            epoch_loss = 0.0
            for step, batch in enumerate(dataloader):
                stats = self.train_step(batch)
                epoch_loss += stats['loss_total']
                if step % log_every == 0:
                    print(f"Epoch {epoch+1} | Step {step:5d} | "
                          f"Loss: {stats['loss_total']:.4f} | "
                          f"CTR: {stats['loss_ctr']:.4f} | "
                          f"RSL: {stats['loss_rsl']:.4f} | "
                          f"Debias: {stats['loss_debias']:.4f}")


# ─── Section 11: Evaluation (AUC, GAUC, PCOC) ───────────────────────────────

def compute_auc(labels: List[float], scores: List[float]) -> float:
    """
    AUC — Formula (20).
    AUC = (1 / |P||N|) Σ_{p∈P, n∈N} 1[Θ(p) > Θ(n)]
    """
    pos = [s for s, l in zip(scores, labels) if l == 1]
    neg = [s for s, l in zip(scores, labels) if l == 0]
    if not pos or not neg:
        return 0.5
    # O(|P|·|N|) pairwise count; ties score 0, matching the strict inequality above
    correct = sum(1 for p in pos for n in neg if p > n)
    return correct / (len(pos) * len(neg))
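

# Sketch: for large evaluation sets the O(|P|·|N|) double loop above gets slow.
# A standard alternative (the Mann-Whitney U statistic) sorts once and uses
# average ranks, O(n log n). One difference: it scores tied pairs as 1/2,
# whereas the strict inequality in Formula (20) scores them as 0, so the two
# agree only when positive and negative scores never tie.
# auc_by_ranks is a hypothetical helper, not part of the original code.
def auc_by_ranks(labels, scores) -> float:
    n = len(scores)
    order = sorted(range(n), key=lambda i: scores[i])
    ranks = [0.0] * n
    i = 0
    while i < n:
        j = i
        while j + 1 < n and scores[order[j + 1]] == scores[order[i]]:
            j += 1
        avg_rank = (i + j) / 2 + 1              # 1-based average rank of the tie group
        for t in range(i, j + 1):
            ranks[order[t]] = avg_rank
        i = j + 1
    n_pos = sum(1 for l in labels if l == 1)
    n_neg = sum(1 for l in labels if l == 0)
    if n_pos == 0 or n_neg == 0:
        return 0.5                              # same fallback as compute_auc
    rank_sum_pos = sum(r for r, l in zip(ranks, labels) if l == 1)
    return (rank_sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)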


def compute_gauc(user_labels: Dict[str, List], user_scores: Dict[str, List]) -> float:
    """
    GAUC — Formula (21).
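    Weighted mean of per-user AUC, weighted by impression count.
    Users whose impressions are all clicks or all non-clicks have no defined AUC;
    they are skipped here rather than counted as 0.5, a common GAUC convention.
    """
    total_imp = 0
    weighted_auc = 0.0
    for uid in user_labels:
        labs   = user_labels[uid]
        scores = user_scores[uid]
        if len(set(labs)) < 2:          # single-class user: AUC undefined
            continue
        n_imp  = len(labs)
        auc_i  = compute_auc(labs, scores)
        weighted_auc += n_imp * auc_i
        total_imp    += n_imp
    return weighted_auc / total_imp if total_imp > 0 else 0.0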


def compute_pcoc(pred_scores: List[float], click_labels: List[float]) -> float:
    """
    PCOC = mean(pred) / mean(click_labels).
    Ideal value: 1.0. Deviation = |PCOC - 1.0|.
    Deviation of 1.7% (PRECTR-V2) vs 2.3% (baseline) as reported.
    """
    mean_pred  = sum(pred_scores) / len(pred_scores)
    mean_click = sum(click_labels) / len(click_labels)
    return mean_pred / (mean_click + 1e-10)


def evaluate(model: PRECTRv2, dataloader, cfg: PRECTRv2Config) -> dict:
    """Evaluate AUC, GAUC, and PCOC on the evaluation set."""
    model.eval()
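    all_preds, all_pctr, all_labels = [], [], []
    user_preds: Dict[str, List] = defaultdict(list)
    user_labs:  Dict[str, List] = defaultdict(list)

    with torch.no_grad():
        for batch in dataloader:
            out    = model(batch)
            scores = out['rank_score'].cpu().tolist()
            pctr   = out['p_click'].cpu().tolist()
            labs   = batch['click_labels'].cpu().tolist()
            uids   = batch['user_ids']

            all_preds.extend(scores)
            all_pctr.extend(pctr)
            all_labels.extend(labs)

            for uid, s, l in zip(uids, scores, labs):
                user_preds[uid].append(s)
                user_labs[uid].append(l)

    auc  = compute_auc(all_labels, all_preds)
    gauc = compute_gauc(user_labs, user_preds)
    # PCOC is a calibration metric, so it compares the predicted probability
    # P(click|x) with observed clicks; rank_score = τ(x)·P(click|x) is not a
    # probability (τ can be negative), so it is used only for AUC and GAUC here.
    pcoc = compute_pcoc(all_pctr, all_labels)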

    return {'AUC': auc, 'GAUC': gauc, 'PCOC': pcoc, 'PCOC_dev': abs(pcoc - 1.0)}


# ─── Section 12: Main Entry Point ────────────────────────────────────────────

if __name__ == "__main__":
    print("=" * 72)
    print("PRECTR-V2: Unified Relevance–CTR Framework")
    print("Cao, Chen, He, Han, Chen · Alibaba Group / Xianyu · arXiv:2602.20676")
    print("=" * 72)

    cfg = PRECTRv2Config()
    torch.manual_seed(cfg.random_seed)
    np.random.seed(cfg.random_seed)
    random.seed(cfg.random_seed)

    print("\n[1] Initializing PRECTR-V2 model...")
    model = PRECTRv2(cfg)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    enc_params   = sum(p.numel() for p in model.encoder.parameters() if p.requires_grad)
    print(f"   Total parameters:   {total_params:,}")
    print(f"   Encoder parameters: {enc_params:,} (~2M target per paper)")
    print(f"   Encoder layers:     {cfg.enc_num_layers} Transformer layers")

    print("\n[2] Config summary:")
    print(f"   Debias margin:     {cfg.debias_margin}")
    print(f"   Debias threshold:  {cfg.debias_threshold}")
    print(f"   p1={cfg.p1}, p2={cfg.p2} for fake_rsl construction")
    print(f"   Cross-user pool:   {cfg.num_sampled_users} users per query category")
    print(f"   Encoder LR scale:  {cfg.enc_lr_scale}x (reduced for joint fine-tuning)")

    print("\n[3] Testing forward pass with synthetic batch...")
    B = 8
    L = 16   # token sequence length
    m = 5    # personal history length
    k = 5    # cross-user history length

    fake_batch = {
        'feature_ids':  torch.randint(1, cfg.vocab_size, (B, cfg.num_fields)),
        'q_cur_ids':    torch.randint(1, 1000, (B, L)),
        'i_cur_ids':    torch.randint(1, 1000, (B, L)),
        'user_q_ids':   torch.randint(1, 1000, (B, m, L)),
        'user_i_ids':   torch.randint(1, 1000, (B, m, L)),
        'cross_q_ids':  torch.randint(1, 1000, (B, k, L)),
        'cross_i_ids':  torch.randint(1, 1000, (B, k, L)),
        'click_labels': torch.randint(0, 2, (B,)),
        'rsl_labels':   torch.randint(1, 5, (B,)),
        'user_ids':     [f'u{i}' for i in range(B)],
    }

    with torch.no_grad():
        out = model(fake_batch)

    print(f"   p_click shape:    {out['p_click'].shape}")
    print(f"   rank_score shape: {out['rank_score'].shape}")
    print(f"   p_click sample:   {out['p_click'][:4].tolist()}")
    print(f"   rank_score sample:{out['rank_score'][:4].tolist()}")

    print("\n[4] Testing debias loss construction...")
    rsl_labels = torch.tensor([4, 4, 2, 4, 1, 4, 3, 4])
    dummy_emb  = torch.randn(B, cfg.text_embed_dim)
    fake_emb, fake_rsl = construct_fake_negatives(dummy_emb, rsl_labels, cfg)
    print(f"   rsl_labels:  {rsl_labels.tolist()}")
    print(f"   fake_rsl:    {fake_rsl.tolist()}")
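    # CTR-scale scores (mean ≈ 0.05) keep mean(f(x+)) below the 0.08 truncation
    # threshold described in compute_debias_loss, so the loss is actually exercised;
    # unscaled sigmoid scores (mean ≈ 0.5) would be truncated to a loss of 0.
    score_p = 0.1 * torch.sigmoid(torch.randn(B))
    score_n = 0.1 * torch.sigmoid(torch.randn(B) - 1.0)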
    debias_l = compute_debias_loss(score_p, score_n, cfg)
    print(f"   Debias loss: {debias_l.item():.4f}")
    print(f"   mean(f(x+)): {score_p.mean().item():.4f} (threshold={cfg.debias_threshold})")

    print("\n[5] Training step demo...")
    trainer = PRECTRv2Trainer(model, cfg)
    stats   = trainer.train_step(fake_batch)
    print(f"   Total loss:  {stats['loss_total']:.4f}")
    print(f"   CTR loss:    {stats['loss_ctr']:.4f}")
    print(f"   RSL loss:    {stats['loss_rsl']:.4f}")
    print(f"   Debias loss: {stats['loss_debias']:.4f}")

    print("\n" + "=" * 72)
    print("PRECTR-V2 model components:")
    print("  LightweightEncoder         — 3-layer Transformer, ~2M params")
    print("  sft_loss / distillation_loss — 2-stage pretraining (Formula 17-19)")
    print("  MultiHeadTargetAttention   — MHTA for preference extraction (Formula 6-7)")
    print("  CrossUserPreferenceMiner   — cold-start cross-user MHTA (Section 4.1)")
    print("  MoEIncentiveModule         — Expert+/Expert− gate τ(x) (Formula 8-10)")
    print("  BaseModule / RSLModule     — factored CTR (Formula 1)")
    print("  construct_fake_negatives   — noise + fake_rsl (Formula 11-13)")
    print("  compute_debias_loss        — margin-clip pairwise (Formula 15-16)")
    print("  PRECTRv2.forward()         — rank_score = τ(x) × P(click|x)")
    print("  PRECTRv2Trainer            — joint optimizer, reduced enc LR")
    print("  evaluate()                 — AUC, GAUC, PCOC (Formula 20-22)")
    print("=" * 72)

Access the Paper and Resources

PRECTR-V2 was released on arXiv in February 2026 by Shuzhi Cao, Rong Chen, Ailong He, Shuguang Han, and Jufeng Chen at Alibaba Group, based on their work on the Xianyu second-hand trading platform’s search system.

Academic Citation:
Cao, S., Chen, R., He, A., Han, S., & Chen, J. (2026). PRECTR-V2: Unified Relevance–CTR Framework with Cross-User Preference Mining, Exposure Bias Correction, and LLM-Distilled Encoder Optimization. arXiv preprint arXiv:2602.20676.

This article is an independent editorial analysis of publicly available research. The views and commentary expressed here reflect the editorial perspective of this site and do not represent the views of the original authors, Alibaba Group, or the Xianyu platform. Diagrams are original illustrations created for this article and do not reproduce figures from the paper.
