PEGN: How Persistent Homology Breaks the WL Barrier in Graph Neural Networks | AI Trend Blend

Loops, Cycles, and the Topology GNNs Cannot See: How PEGN Breaks the Weisfeiler-Lehman Ceiling

A multi-institution team spanning Peking University, UC San Diego, Stony Brook University, and Weill Cornell Medicine built PEGN, a framework that injects localized topological features from extended persistent homology into graph neural networks. The resulting features provably separate graph pairs that the 3-WL test cannot, and they improve results on node classification and link prediction.


Graph neural networks have a dirty secret: two graphs that look completely different to a human eye can be, from the GNN’s perspective, utterly indistinguishable. This is not an implementation detail but a fundamental ceiling: standard message-passing GNNs are provably no better at telling graphs apart than the Weisfeiler-Lehman (1-WL) isomorphism test, a combinatorial color-refinement algorithm. A team from Peking University, UCSD, Stony Brook University, and Weill Cornell Medicine decided to attack this ceiling from an unusual direction: not by redesigning the message-passing mechanism, but by augmenting it with topological information that local neighborhood aggregation cannot recover on its own.


The Loop-Counting Problem Standard GNNs Cannot Solve

Here is the failure mode, stated as concretely as possible. Draw (a) two separate triangles: six nodes, every node of degree 2, two distinct loops. Now draw (b) a single hexagon: six nodes, every node of degree 2, one loop. Feed both to any standard message-passing GNN with identical input features and ask it to distinguish them. It cannot. Every node in both graphs produces the same embedding after any number of propagation rounds, because the 1-WL algorithm, which recolors each node from its own color and the multiset of its neighbors' colors, never finds a difference: every node in either graph always sees exactly two neighbors with the same color as itself.

Adding node degree as a feature does not help either, since every node has degree 2 in both graphs. What would actually work is counting loops. Graph (a) has two; graph (b) has one. A model that knows how many cycles surround a node can immediately tell them apart. The challenge is extracting that loop information in a form that is stable, computable, and compatible with modern deep learning pipelines.
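A quick way to see this kind of failure is the classic 2-regular pair: a hexagon and two disjoint triangles. The minimal pure-Python sketch below (an illustration, not code from the PEGN repository; the helper names are hypothetical) runs 1-WL color refinement jointly on both graphs with identical starting colors and compares the color histograms, then counts independent loops with the cycle rank |E| - |V| + (number of components).

```python
from collections import Counter


def wl_histograms(graphs, rounds=3):
    """Run 1-WL color refinement jointly (shared palette) on several graphs,
    each given as (n, edge_list); return each graph's color histogram."""
    nbrs = []
    for n, edges in graphs:
        adj = [[] for _ in range(n)]
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)
        nbrs.append(adj)
    colors = [[0] * n for n, _ in graphs]  # identical input features
    for _ in range(rounds):
        # New color = (own color, sorted multiset of neighbor colors)
        sigs = [[(c[i], tuple(sorted(c[j] for j in adj[i]))) for i in range(len(c))]
                for c, adj in zip(colors, nbrs)]
        palette = {s: k for k, s in enumerate(sorted({s for g in sigs for s in g}))}
        colors = [[palette[s] for s in g] for g in sigs]
    return [Counter(c) for c in colors]


def cycle_rank(n, edges):
    """First Betti number of a graph: |E| - |V| + number of components."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    components = n
    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            components -= 1
    return len(edges) - n + components


hexagon = (6, [(i, (i + 1) % 6) for i in range(6)])
two_triangles = (6, [(0, 1), (1, 2), (2, 0), (3, 4), (4, 5), (5, 3)])

h1, h2 = wl_histograms([hexagon, two_triangles])
print(h1 == h2)                                           # True
print(cycle_rank(*hexagon), cycle_rank(*two_triangles))   # 1 2
```

The histograms match at every round, while the loop counts differ, which is exactly the gap topological features are meant to fill.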

That is precisely what persistent homology offers. And the paper by Zuoyu Yan, Qi Zhao, Ze Ye, Tengfei Ma, Liangcai Gao, Zhi Tang, Yusu Wang, and Chao Chen — published in the Journal of Machine Learning Research in 2025 — is the first work to do it in a genuinely localized way, generating per-node and per-edge topological features rather than a single global summary of the whole graph.

Core Insight

Standard message-passing GNNs are provably bounded by the 1-WL test and cannot count cycles, measure loop lifespans, or capture the connectivity richness around a node. PEGN sidesteps this by computing extended persistence diagrams on the local neighborhood of each node and edge — encoding exactly the topological information that 1-WL permanently discards.


Persistent Homology in Plain Terms

Persistent homology is a tool from algebraic topology that tracks how topological structures (connected components, loops, voids) are born and destroyed as you gradually build up a shape by sweeping a threshold. Assign each node in a graph a scalar value (called a filter function, like node degree or graph-geodesic distance). Start with an empty graph. Slowly raise the threshold, adding each node once its filter value is reached and each edge once both of its endpoints are present. A connected component is born when a new isolated node appears; it dies when it merges with an older component. A loop is born when an edge closes a cycle; in ordinary persistent homology on a graph, that loop never dies, because a graph has no higher-dimensional cells that could fill it in.
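The ascending sweep can be sketched with a union-find structure. This is a minimal illustration under the conventions just described (node values as the filter, an edge entering at the larger of its endpoint values, the elder rule for merges); `ascending_sweep` is a hypothetical helper, not part of the PEGN code.

```python
def ascending_sweep(n, edges, f):
    """Sublevel-set sweep of a graph with node filter values f.

    Returns (component_pairs, loop_births). Components follow the elder rule:
    when two merge, the younger one dies. The oldest component and every loop
    never die in ordinary persistence, so the former gets death = inf and the
    latter are reported by birth value only.
    """
    parent = list(range(n))
    birth = list(f)  # birth value of the component rooted at each node

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    pairs, loop_births = [], []
    # An edge enters once both endpoints are present: at max(f(u), f(v))
    for u, v in sorted(edges, key=lambda e: max(f[e[0]], f[e[1]])):
        t = max(f[u], f[v])
        ru, rv = find(u), find(v)
        if ru == rv:
            loop_births.append(t)  # the edge closes a cycle
            continue
        young, old = (ru, rv) if birth[ru] > birth[rv] else (rv, ru)
        pairs.append((birth[young], t))  # younger component dies at t
        parent[young] = old
    pairs += [(birth[r], float("inf")) for r in {find(i) for i in range(n)}]
    return sorted(pairs), loop_births


# A 4-cycle whose node values rise from node 0 to node 2
square = [(0, 1), (1, 2), (2, 3), (3, 0)]
pairs, loops = ascending_sweep(4, square, [0, 1, 2, 1])
print(pairs)  # [(0, inf), (1, 1), (1, 1), (2, 2)]
print(loops)  # [2]
```

Zero-length pairs such as (1, 1) are normal: a node that joins an existing component the moment it appears is born and dies at the same value. The single loop is born at 2 and, in this ordinary sweep, never dies.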

That last point is the reason the paper uses extended persistent homology (EPH) rather than the classical version. EPH appends a descending filtration after the ascending one, sweeping back down from the maximum filter value to the minimum. This extended sequence guarantees that every loop is eventually killed, giving it a well-defined birth-death pair called a persistence point \((b, d)\). The lifetime \(|b - d|\) measures how structurally significant the loop is. All these persistence points together form an Extended Persistence Diagram (EPD): a 2D multiset of points, one per topological feature, that constitutes a mathematically stable fingerprint of the graph’s topology.

Ascending Filter (Edges) $$f(uv) = \max(f(u),\, f(v))$$
Descending Filter (Edges) $$f(uv) = \min(f(u),\, f(v))$$
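For the simplest case, a graph that is a single cycle, both edge rules can be applied by hand: the loop is born when its last edge enters the ascending pass and dies when its last edge enters the descending pass, so its extended persistence point is (max f, min f). The snippet is a sketch of that special case only; general graphs need a full extended-persistence algorithm (the paper uses the Dionysus library), and `edge_filter_values` is a hypothetical helper.

```python
def edge_filter_values(edges, f):
    """Edge values for both passes of extended persistence."""
    ascending = {e: max(f[e[0]], f[e[1]]) for e in edges}   # f(uv) = max(f(u), f(v))
    descending = {e: min(f[e[0]], f[e[1]]) for e in edges}  # f(uv) = min(f(u), f(v))
    return ascending, descending


f = [0.0, 1.0, 2.0, 1.0]                    # node filter values
cycle = [(0, 1), (1, 2), (2, 3), (3, 0)]    # a single 4-cycle
asc, desc = edge_filter_values(cycle, f)

# Single-cycle special case: birth at the last ascending edge,
# death at the last descending edge.
birth, death = max(asc.values()), min(desc.values())
print(birth, death, abs(birth - death))     # 2.0 0.0 2.0
```

The point (2.0, 0.0) sits below the diagonal, which is why lifetimes are measured as \(|b - d|\) rather than \(d - b\).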

The stability result — that small perturbations to the filter function cause only small changes to the persistence diagram — is what makes EPDs worth building a machine learning system around. You get both expressiveness and robustness from the same mathematical object.

Figure 1: The extended persistence pipeline. A graph is filtered ascending then descending; topological structures (components in green/red, loops in blue) live between their birth and death threshold values. Each structure yields one persistence point in the EPD. The diagram is then converted to a fixed-size persistence image via 2D Gaussian smoothing for use in neural networks.

From Diagrams to Vectors: The Persistence Image

EPDs live in a variable-size space — different graphs produce diagrams with different numbers of points — which means they cannot be fed directly into a GNN’s fixed-size pipeline. The paper uses persistence images, a vectorisation technique introduced by Adams et al. (2017). The process applies a linear transformation \(T(x,y) = (x, y-x)\) to each persistence point, places a 2D Gaussian around each transformed point, weights it by a piecewise-linear function that suppresses points near the diagonal (those with short lifespans), and integrates over a pixel grid. The result is a fixed-size vector in \(\mathbb{R}^n\) — stable with respect to the Wasserstein distance on diagrams.

Persistence Surface $$\rho_D(z) = \sum_{u \in T(D)} \alpha(u)\,\phi_u(z), \quad \phi_u(z) = \frac{1}{2\pi\sigma^2} e^{-\|z-u\|^2/(2\sigma^2)}$$
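A compact NumPy sketch of the persistence surface evaluated on a grid. It assumes the absolute lifetime \(|d - b|\) as the second coordinate (so points below the diagonal are handled), a weight \(\alpha\) that ramps linearly from 0 to 1, and evaluation at grid points in place of exact pixel integration; the grid range, resolution, and \(\sigma\) are illustrative choices, not the paper's settings.

```python
import numpy as np


def persistence_image(points, resolution=10, sigma=0.1, lo=0.0, hi=1.0, ramp=1.0):
    """Evaluate rho_D(z) = sum_u alpha(u) * phi_u(z) on a square grid.

    Each (b, d) is mapped to u = (b, |d - b|); alpha grows linearly with the
    lifetime and saturates at 1 once the lifetime reaches `ramp`.
    Returns the flattened resolution x resolution image.
    """
    grid = np.linspace(lo, hi, resolution)
    X, Y = np.meshgrid(grid, grid, indexing="ij")  # X: birth axis, Y: lifetime axis
    surface = np.zeros_like(X)
    for b, d in points:
        life = abs(d - b)
        alpha = min(life / ramp, 1.0)              # zero on the diagonal
        phi = np.exp(-((X - b) ** 2 + (Y - life) ** 2) / (2 * sigma ** 2))
        surface += alpha * phi / (2 * np.pi * sigma ** 2)
    return surface.ravel()


diagram = [(0.8, 0.2), (0.5, 0.5)]   # the second point has zero lifetime
vec = persistence_image(diagram)
print(vec.shape)                     # (100,)
```

The zero-lifetime point contributes nothing because its weight is 0, which is how the weighting suppresses near-diagonal noise.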

The authors also propose a structural augmentation called persistence image plus (PI+), which concatenates the standard persistence image with explicit counts of level-\(j\) nodes, intra-layer edges, and inter-layer edges:

PI+ Formulation (Eq. 4) $$PI^+ = [PI;\; \{n_j \mid j \leq k\};\; \{n_{j,j} \mid j \leq k\};\; \{n_{j,(j+1)} \mid j < k\}]$$

This addition is not just cosmetic. As the theorems in Section 4.1 establish, the raw persistence image can struggle to convey certain intra-layer and inter-layer connection counts that the EPD fundamentally encodes. PI+ makes those counts explicit, improving representation power at negligible computational cost.
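The three count families in PI+ come straight out of a breadth-first search around the root. The function below is a hypothetical pure-Python sketch of those counts only (the persistence-image part is omitted), assuming that "level \(j\)" means hop distance \(j\) from the root.

```python
from collections import deque


def pi_plus_counts(adj_list, root, k):
    """Return [n_0..n_k, n_00..n_kk, n_01..n_(k-1)k] for the k-hop subgraph."""
    level = {root: 0}
    queue = deque([root])
    while queue:
        u = queue.popleft()
        if level[u] == k:
            continue
        for v in adj_list[u]:
            if v not in level:
                level[v] = level[u] + 1
                queue.append(v)
    n = [0] * (k + 1)
    intra = [0] * (k + 1)
    inter = [0] * k
    for u, lu in level.items():
        n[lu] += 1
        for v in adj_list[u]:
            if v not in level or u > v:
                continue  # count each edge once, inside the subgraph only
            lv = level[v]
            if lu == lv:
                intra[lu] += 1
            else:
                inter[min(lu, lv)] += 1  # BFS levels of adjacent nodes differ by 1
    return n + intra + inter


# Root 0; neighbors 1 and 2 are linked; node 3 is adjacent to both 1 and 2
adj_list = {0: [1, 2], 1: [0, 2, 3], 2: [0, 1, 3], 3: [1, 2]}
counts = pi_plus_counts(adj_list, root=0, k=2)
print(counts)  # [1, 2, 1, 0, 1, 0, 2, 2]
```

The intra-layer edge between nodes 1 and 2 is exactly the kind of count that is easy to read here but can be hard to recover from the smoothed image alone.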


The Architecture: Vicinity Graphs, Filter Functions, and PEGN

Locality First: The Vicinity Graph

Prior work on persistent homology for graphs mostly computed a single EPD for the entire graph — a global summary. That does not translate into useful per-node or per-edge representations because two nodes in very different structural positions would share the same global diagram. PEGN is the first framework to make persistent homology local.

For a target node \(u\), the system extracts the \(k\)-hop subgraph \(G^k_u\): all nodes within \(k\) hops and all edges between them. The filter function is then the shortest-path distance from each node to \(u\) within this subgraph, with each edge weighted by its Ollivier-Ricci curvature. Ollivier-Ricci curvature compares two adjacent nodes through the optimal-transport (Wasserstein) cost of moving the probability mass on one node's neighbors onto the other's, and has been shown to improve graph convolution both theoretically and empirically. Adding 1 to the curvature ensures all edge weights are positive.
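A sketch of the node-level filter under stated assumptions: the Ollivier-Ricci curvature of each edge is taken as precomputed input (the values in the example are made up for illustration), edge weights are 1 + κ as described above, and distances come from Dijkstra's algorithm restricted to the \(k\)-hop subgraph. Both helpers are hypothetical.

```python
import heapq
from collections import deque


def k_hop_nodes(adj_list, u, k):
    """Nodes within k hops of u (breadth-first search)."""
    hops = {u: 0}
    queue = deque([u])
    while queue:
        x = queue.popleft()
        if hops[x] == k:
            continue
        for y in adj_list[x]:
            if y not in hops:
                hops[y] = hops[x] + 1
                queue.append(y)
    return set(hops)


def curvature_filter(adj_list, curvature, u, k):
    """Shortest-path distance to u inside G^k_u with edge weight 1 + kappa.

    `curvature` maps frozenset({x, y}) to a precomputed curvature value;
    Dijkstra needs non-negative weights, i.e. kappa >= -1.
    """
    nodes = k_hop_nodes(adj_list, u, k)
    dist = {u: 0.0}
    heap = [(0.0, u)]
    while heap:
        d, x = heapq.heappop(heap)
        if d > dist[x]:
            continue  # stale heap entry
        for y in adj_list[x]:
            if y not in nodes:
                continue  # stay inside the vicinity graph
            nd = d + 1.0 + curvature[frozenset((x, y))]
            if nd < dist.get(y, float("inf")):
                dist[y] = nd
                heapq.heappush(heap, (nd, y))
    return dist


adj_list = {0: [1, 3], 1: [0, 2], 2: [1, 3], 3: [2, 0]}      # a 4-cycle
kappa = {frozenset((0, 1)): 0.0, frozenset((1, 2)): 0.0,
         frozenset((2, 3)): -0.75, frozenset((3, 0)): 0.5}    # illustrative values
filt = curvature_filter(adj_list, kappa, u=0, k=2)
print(filt)
```

Node 2 ends up at distance 1.75 via the negatively curved edge (2, 3) rather than 2.0 via node 1, showing how the curvature reshapes the filter compared with plain hop counts.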

For a target edge \((u, v)\) in the link prediction setting, the vicinity graph is the intersection of the \(k\)-hop neighborhoods of \(u\) and \(v\). The filter function for any node \(i\) in this intersection is the sum of its distances to the two target nodes: \(f(i) = d_o(i,u) + d_o(i,v)\). This pair-wise design is crucial — it captures the common structural context shared between the two nodes, which a node-wise feature would miss entirely.
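The pair-wise filter can be sketched the same way. For brevity this version assumes hop counts as a stand-in for the curvature-weighted distance \(d_o\) and measures distances in the full graph; `pairwise_filter` and `hop_distances` are hypothetical helpers.

```python
from collections import deque


def hop_distances(adj_list, src, k):
    """Hop distances from src, truncated at k hops."""
    dist = {src: 0}
    queue = deque([src])
    while queue:
        x = queue.popleft()
        if dist[x] == k:
            continue
        for y in adj_list[x]:
            if y not in dist:
                dist[y] = dist[x] + 1
                queue.append(y)
    return dist


def pairwise_filter(adj_list, u, v, k):
    """Vicinity graph of target pair (u, v): nodes in both k-hop
    neighborhoods, with filter f(i) = d(i, u) + d(i, v)."""
    du = hop_distances(adj_list, u, k)
    dv = hop_distances(adj_list, v, k)
    return {i: du[i] + dv[i] for i in du.keys() & dv.keys()}


# u = 0 and v = 1 share neighbors 2 and 3; node 4 hangs off u only
adj_list = {0: [1, 2, 3, 4], 1: [0, 2, 3], 2: [0, 1], 3: [0, 1], 4: [0]}
pair_filter = pairwise_filter(adj_list, 0, 1, k=1)
print(pair_filter)
```

Node 4 is dropped because it lies outside the 1-hop neighborhood of \(v\), while the common neighbors 2 and 3 are kept with filter value 2: the vicinity graph retains only the shared context of the pair.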

Node Classification Branch

In the node classification branch, the persistence image \(PI(u)\) of node \(u\) is used to reweight the messages that \(u\) receives from its neighbors. For each directed message from neighbor \(v\) to \(u\), the edge weight is computed as:

Edge Weight (Eq. 6) $$\alpha^{k-1}_{uv} = \sigma\!\left(W^{k-1}_2\, [PI(u);\, PI(v)]\right)$$

where \(W^{k-1}_2\) is a two-layer MLP with a 50-unit hidden layer, and the two persistence images are concatenated before being passed through it. The message-passing update then becomes:

Message Passing (Eq. 5) $$h^k_u = \sigma_{k-1}\!\left(\sum_{v \in \mathcal{N}(u)} \alpha^{k-1}_{uv}\, W^{k-1}_1\, h^{k-1}_v\right)$$
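A NumPy sketch of one such layer makes the data flow of Eqs. 5 and 6 concrete. The ReLU inside the edge scorer, the tanh standing in for \(\sigma_{k-1}\), the function name `topo_layer`, and all shapes are assumptions for illustration.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def topo_layer(h, pi, edges, W1, A, a, B, b):
    """One topology-weighted message-passing layer (Eqs. 5 and 6).

    h     : (N, d_in) node features h^{k-1}
    pi    : (N, p) persistence images
    edges : directed (v, u) pairs; v sends a message to u
    W1    : (d_in, d_out) message weights W_1
    A, a  : (2p, hidden), (hidden,) first layer of the edge scorer W_2
    B, b  : (hidden,), scalar second layer of the edge scorer
    """
    out = np.zeros((h.shape[0], W1.shape[1]))
    for v, u in edges:
        z = np.concatenate([pi[u], pi[v]])                   # [PI(u); PI(v)]
        alpha = sigmoid(np.maximum(z @ A + a, 0.0) @ B + b)  # Eq. 6, a scalar
        out[u] += alpha * (h[v] @ W1)                        # Eq. 5 summand
    return np.tanh(out)                                      # sigma_{k-1}


rng = np.random.default_rng(0)
N, d_in, d_out, p, hidden = 3, 4, 2, 25, 50
h, pi = rng.normal(size=(N, d_in)), rng.random((N, p))
edges = [(1, 0), (0, 1), (2, 1), (1, 2)]  # path 0 - 1 - 2, both directions
params = (rng.normal(size=(d_in, d_out)), rng.normal(size=(2 * p, hidden)) * 0.1,
          np.zeros(hidden), rng.normal(size=hidden) * 0.1, 0.0)
out = topo_layer(h, pi, edges, *params)
print(out.shape)  # (3, 2)
```

Every message is scaled by a gate in (0, 1) computed from the two endpoints' persistence images, so topology decides how loudly each neighbor is heard.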

The topology is not an add-on feature concatenated at the end. It directly shapes which parts of the neighborhood get listened to more carefully, integrating structural information into the propagation mechanism itself.

Link Prediction Branch

The link prediction architecture has two parallel branches. The upper branch runs a standard 2-layer GCN to produce node embeddings \(h_u\) and \(h_v\) for the target node pair. The lower branch computes the pair-wise persistence image \(PI(u,v)\) from the intersection vicinity graph. A modified Fermi-Dirac decoder then combines both streams:

Fermi-Dirac Decoder $$FD(u,v) = \text{MLP}\!\left([\,h_u - h_v;\; PI(u,v)\,]\right), \quad p_{uv} = \frac{1}{e^{FD(u,v)-2} + 1}$$
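A minimal NumPy sketch of the decoder. The one-hidden-layer MLP, its sizes, and the random weights are assumptions for illustration; the general Fermi-Dirac form \(1/(e^{(s-r)/t}+1)\) reduces to the formula above with \(r = 2\) and \(t = 1\).

```python
import numpy as np


def fermi_dirac(score, r=2.0, t=1.0):
    """p = 1 / (exp((score - r) / t) + 1); r = 2, t = 1 gives the formula above."""
    return 1.0 / (np.exp((score - r) / t) + 1.0)


def fd_score(h_u, h_v, pi_uv, W, w):
    """Hypothetical one-hidden-layer MLP over the concatenation [h_u - h_v ; PI(u, v)]."""
    x = np.concatenate([h_u - h_v, pi_uv])
    return float(np.maximum(x @ W, 0.0) @ w)


rng = np.random.default_rng(1)
h_u, h_v, pi_uv = rng.normal(size=8), rng.normal(size=8), rng.random(25)
W, w = rng.normal(size=(33, 16)) * 0.1, rng.normal(size=16) * 0.1
p_uv = fermi_dirac(fd_score(h_u, h_v, pi_uv, W, w))
print(p_uv)               # a probability in (0, 1)
print(fermi_dirac(2.0))   # 0.5
```

A score of exactly 2 gives probability 0.5, and larger scores push the probability toward 0, so the MLP effectively learns a distance between the two endpoints.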
Figure 2: The pair of graphs that 3-WL cannot differentiate. (a) The Shrikhande Graph. (b) The 4×4 Rook Graph.

Theoretical Guarantees: Beating 3-WL and 4-WL

The theoretical analysis is one of the more satisfying parts of the paper. The authors do not just claim their features are “more expressive” in vague terms — they prove specific separation results against the k-WL hierarchy.

Theorem 1 says that with the shortest-path distance as the filter function, EPDs can differentiate graph pairs that 3-WL cannot separate. The proof uses the 4×4 Rook Graph and the Shrikhande Graph, a famous pair of strongly regular graphs with identical parameters that 3-WL cannot tell apart. In the Rook graph, the six neighbors of every node induce two disjoint 3-cycles, so their 1st Betti number is 2. In the Shrikhande graph, the six neighbors of every node induce a single 6-cycle, so their Betti number is 1. EPDs capture these Betti numbers directly from the persistence point at coordinates \((1.5, 0.5)\), so the pair is instantly separated.
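The neighborhood structure this proof relies on can be checked in a few lines of pure Python. The sketch builds both graphs as Cayley graphs on Z4 x Z4 (a standard construction: the Rook graph joins cells that share a row or column, and the Shrikhande graph uses the connection set ±(1,0), ±(0,1), ±(1,1)), confirms both are 6-regular, and computes the 1st Betti number of each node's neighbor subgraph as |E| - |V| + (number of components). It does not compute EPDs, and the helper names are hypothetical.

```python
from itertools import product


def cayley_graph_z4z4(connection_set):
    """Cayley graph on Z4 x Z4: x ~ y iff (y - x) mod 4 is in connection_set."""
    nodes = list(product(range(4), repeat=2))
    return {x: {((x[0] + s[0]) % 4, (x[1] + s[1]) % 4) for s in connection_set}
            for x in nodes}


def neighbor_betti1(adj, x):
    """beta_1 of the subgraph induced by the neighbors of x."""
    nbrs = adj[x]
    edges = {frozenset((a, b)) for a in nbrs for b in adj[a] if b in nbrs}
    seen, components = set(), 0
    for s in nbrs:  # count connected components with an iterative DFS
        if s in seen:
            continue
        components += 1
        seen.add(s)
        stack = [s]
        while stack:
            a = stack.pop()
            for b in adj[a]:
                if b in nbrs and b not in seen:
                    seen.add(b)
                    stack.append(b)
    return len(edges) - len(nbrs) + components


rook = cayley_graph_z4z4([(a, 0) for a in (1, 2, 3)] + [(0, b) for b in (1, 2, 3)])
shrikhande = cayley_graph_z4z4([(1, 0), (3, 0), (0, 1), (0, 3), (1, 1), (3, 3)])

print({len(n) for n in rook.values()}, {len(n) for n in shrikhande.values()})  # {6} {6}
print({neighbor_betti1(rook, x) for x in rook})                                # {2}
print({neighbor_betti1(shrikhande, x) for x in shrikhande})                    # {1}
```

Every node of both graphs has degree 6, so degree-based features see nothing, yet the local loop count differs uniformly: 2 in the Rook graph and 1 in the Shrikhande graph.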

Theorem 2 pushes further: with the more sophisticated distance function used for link prediction, EPDs can differentiate pairs that even 4-WL cannot. The proof uses a Cai-Fürer-Immerman-style construction — a family of graphs specifically designed to fool high-order WL tests — and shows that the pair-wise filter function produces different 1st Betti numbers for the two graphs’ common-neighbor substructures.

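Theorem 3 provides the honest upper bound: EPDs are less powerful than 4-WL in general, meaning there are graph pairs that 4-WL separates but EPDs do not. Read together with Theorem 2, this makes EPDs and 4-WL incomparable rather than strictly ranked. This is good epistemics: the paper is not claiming to have solved graph isomorphism, just to have made a meaningful improvement within a well-characterised range.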

Theorem 4 addresses random regular graphs, a setting where 1-WL fails completely: it gives every node of an \(r\)-regular graph the same color, so it cannot separate any two \(r\)-regular graphs of the same size. The theorem shows that EPDs can distinguish a \(1 - o(n^{-1/2})\) fraction of pairs of \(n\)-node \(r\)-regular graphs using at most \(K = \lfloor(1/2 + \epsilon)\log_2 n / \log(r-1)\rfloor\) hops. In other words, almost all pairs of regular graphs that confuse WL are separated by EPDs. Because the guarantee is stated for random regular graphs, it speaks most directly to graphs with near-uniform degree distributions.

“The localized topological feature captures the connectivity richness around a node or edge of interest — structural information that is hard for standard message-passing GNNs to recover, but can be recovered from EPDs by simple queries that can be easily implemented using an MLP.” — Yan, Zhao, Ye, Ma, Gao, Tang, Wang and Chen, JMLR (2025)

Results: Where the Topology Actually Helps

Node Classification

The benchmark covers seven datasets ranging from small sparse citation graphs (Cora, Citeseer at ~3000 nodes) to large dense co-purchase and co-authorship networks (Physics at 34,493 nodes, Computers at 13,381 nodes). The pattern that emerges is consistent: PEGN and its variants outperform all baselines on the larger, denser graphs, while performing comparably on the small sparse ones.

| Method | Cora | CS | Physics | Computers | Photo |
|---|---|---|---|---|---|
| GCN | 81.5±0.5 | 91.1±0.5 | 92.8±1.0 | 82.6±2.4 | 91.2±1.2 |
| GAT | 83.0±0.7 | 90.5±0.6 | 92.5±0.9 | 78.0±19.0 | 85.1±20.3 |
| Graph U-Net | 82.5±0.0 | 92.7±0.0 | 94.0±0.0 | 86.0±0.0 | 91.9±0.0 |
| CGNN-EXP | 82.5±0.6 | 93.2±0.3 | 93.7±0.3 | 83.8±0.7 | 91.4±0.6 |
| PEGN | 82.7±0.4 | 93.3±0.3 | 94.3±0.1 | 86.6±0.6 | 92.7±0.4 |
| GCN (+PI+) | 83.6±1.2 | 93.0±0.3 | 94.1±0.3 | 86.9±0.5 | 93.0±0.7 |

Table 1: Node classification accuracy (%). On every dataset shown, the best result comes from either PEGN or GCN (+PI+). PEGN itself trails GAT on Cora, consistent with small sparse graphs containing less topological signal for EPDs to exploit.

The muted improvement on Cora and Citeseer is not a failure — it is an informative diagnostic. Those graphs are sparse enough that most nodes have few loops nearby, so the persistence diagrams are uninformative. The topological signal only exists if there is actual topology to measure. On Physics, Computers, and Photo, where edges are plentiful and local structure is rich, the gains are clear and consistent.

Link Prediction

The link prediction results are where the pair-wise topology feature really earns its keep. On Cora, Citeseer, PubMed, Photo, and Computers, PEGN achieves the best or second-best AUC-ROC among all baselines, including SEAL — a powerful subgraph-based method that also leverages local structural information but at much higher computational cost.

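| Method | Cora | Citeseer | PubMed | Photo | Computers |
|---|---|---|---|---|---|
| GCN | 90.5±0.2 | 82.6±1.9 | 89.6±3.7 | 91.8±0.0 | 87.8±0.0 |
| SEAL | 91.3±5.7 | 89.8±2.3 | 92.4±1.2 | 97.8±1.3 | 96.8±1.5 |
| HGCN | 93.8±0.1 | 96.6±0.1 | 96.3±0.0 | 95.4±0.0 | 93.6±0.0 |
| PEGN | 94.9±0.4 | 95.1±0.7 | 97.0±0.1 | 98.2±0.1 | 97.9±0.1 |
| PEGN (PI) | 95.0±0.3 | 93.1±0.7 | 96.8±0.2 | 98.4±0.1 | 98.3±0.1 |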

Table 2: AUC-ROC (%) for link prediction. PEGN outperforms all baselines on most benchmarks. The pair-wise formulation consistently beats node-wise topological variants (PEGN NC, PEGN node-wise), confirming that capturing the common neighbourhood of target pairs is what drives the gain.

Key Empirical Finding

The pair-wise persistence image (computed from the intersection of the two target nodes’ neighborhoods) consistently outperforms node-wise variants for link prediction. This validates the theoretical claim that common structural context — not just individual node topology — is what matters for predicting whether edges exist.


End-to-End Learning: Promising but Expensive

The paper also explores two strategies for learning the filter function rather than fixing it as Ollivier-Ricci curvature. The first — “combination of features” (PEGN PI) — uses multiple precomputed filter functions (Ollivier-Ricci curvature and two Heat Kernel Signature functions with diffusion parameters \(t = 0.1\) and \(t = 10\)) and learns to combine the resulting persistence images. The second — “learnable filter function” (PEGN FF) — initialises filter values as node degrees, then learns them through graph convolutional blocks before computing persistence images.
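The Heat Kernel Signature has a standard spectral definition, \(\mathrm{HKS}(x, t) = \sum_i e^{-\lambda_i t}\,\phi_i(x)^2\), where \((\lambda_i, \phi_i)\) are Laplacian eigenpairs. The sketch below uses the combinatorial Laplacian \(L = D - A\); which Laplacian (and any normalization) the paper uses is an assumption here, as is the example graph.

```python
import numpy as np


def heat_kernel_signature(adj, t):
    """HKS(x, t) = sum_i exp(-lambda_i * t) * phi_i(x)**2 for L = D - A."""
    lap = np.diag(adj.sum(axis=1)) - adj
    eigvals, eigvecs = np.linalg.eigh(lap)  # columns of eigvecs are the phi_i
    return (np.exp(-eigvals * t) * eigvecs ** 2).sum(axis=1)


adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 1],
                [0, 1, 0, 1],
                [0, 1, 1, 0]], dtype=float)  # a triangle with a pendant node
for t in (0.1, 10.0):                        # the two diffusion times above
    print(t, np.round(heat_kernel_signature(adj, t), 3))
```

At \(t = 0.1\) the values still differ between the pendant node and the triangle nodes, while at \(t = 10\) every node approaches \(1/n\) on a connected graph. Mixing a small and a large \(t\) therefore gives complementary local and global views.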

The verdict is nuanced. For node classification, neither end-to-end strategy consistently beats the fixed Ollivier-Ricci curvature baseline — suggesting that the curvature is already a near-optimal filter for semantic node-level tasks, and adding learnable parameters mainly introduces overfitting risk. For link prediction on the PPI datasets, PEGN FF achieves state-of-the-art results, suggesting that when the task is more structural in nature (does this edge exist?), a learned filter can better align with the relevant connectivity patterns. The computational cost of the learnable approach is, however, substantially higher — persistence images must be recomputed every epoch as filter values change, making it impractical for large graphs without aggressive optimisation.


What the Paper Does Not Claim, and Why That Matters

Theoretical honesty is one of the more underrated virtues in machine learning papers, and PEGN demonstrates it. The expressiveness results are framed carefully: EPDs separate the Rook/Shrikhande pair that 3-WL cannot, and they can separate some pairs that 4-WL cannot, yet they do not match 4-WL in general. The authors do not claim to have solved graph isomorphism or to have the most expressive possible graph representation.

The computational cost is also handled candidly. Precomputing all vicinity graph persistence images for a dataset is an offline step that can be parallelised, but it is not free — for large dense graphs, this preprocessing can be the bottleneck. The learnable filter variant is slower still. For datasets like PPI with 20 graphs of 3,000 nodes each and average degree 28.8, the authors report that some existing baselines either run out of memory or fail to converge, and do not compare against them for that reason.


Why This Approach Will Keep Mattering

The fundamental argument for topological augmentation of GNNs is structural, not empirical. Message-passing architectures aggregate local neighborhood information iteratively. They are excellent at propagating feature signals across short distances and capturing degree and density statistics. What they cannot do — and what the WL hierarchy proves they cannot do — is count cycles, measure loop persistence, or distinguish structural configurations that look locally identical but differ globally.

Persistent homology was designed precisely to measure those things. The innovation in PEGN is making that measurement local — one persistence diagram per node or edge, computed from the immediate structural context rather than the whole graph. This preserves the locality bias that makes GNNs scale, while plugging the specific hole that WL-bounded expressiveness creates.

The PI+ formulation — adding explicit layer-count and edge-count statistics to the persistence image — is a small but instructive addition. It shows that the persistence image alone does not always convey all the structural information the EPD encodes; sometimes you need to help the model along by making key counts explicit. That kind of careful ablation work, showing exactly where and why the representation fails without the augmentation, is what distinguishes a thorough systems paper from a benchmark report.

The extension to end-to-end learning closes an important conceptual loop. If topological features are worth computing, they should also be learnable — the filter function that maximises task performance may not be the one that maximises topological expressiveness in the abstract. The result that Ollivier-Ricci curvature works well for node classification, while learnable filters improve link prediction, is a useful practical guideline for practitioners choosing how to deploy the framework.

There is a broader lesson here about the design space of graph representation learning. The field has spent considerable effort on redesigning the aggregation function: attention, higher-order convolutions, subgraph-based methods. PEGN suggests an orthogonal strategy: keep the aggregation function relatively simple, but augment it with information that aggregation fundamentally cannot produce. The topological features are computed offline rather than learned by gradient descent (in the basic version), and they push the model’s expressiveness above the WL ceiling with little or no change to the GNN backbone. That modularity is a genuine engineering advantage.

The code is publicly available at github.com/pkuyzy/TLC-GNN, with the Dionysus library used for EPD computation. Any GNN practitioner who wants to add topological features to an existing model can, in principle, precompute the persistence images for their dataset and concatenate them with the existing node features — no architectural overhaul required.
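Under that reading, the drop-in route amounts to a feature concatenation. The sketch assumes the persistence images have already been computed per node and stored as an array aligned with the node order; `augment_features` is a hypothetical helper, and the column-wise standardization is an assumed preprocessing choice rather than something the paper prescribes.

```python
import numpy as np


def augment_features(x, pi, eps=1e-8):
    """Append standardized persistence images to node features.

    x  : (N, d) original node features
    pi : (N, p) precomputed per-node persistence images (or PI+ vectors)
    """
    pi_std = (pi - pi.mean(axis=0)) / (pi.std(axis=0) + eps)  # z-score each PI column
    return np.concatenate([x, pi_std], axis=1)


x = np.random.default_rng(0).random((5, 16))
pi = np.random.default_rng(1).random((5, 25))
x_aug = augment_features(x, pi)
print(x_aug.shape)   # (5, 41)
```

The augmented matrix would then replace the input features of an otherwise unchanged GNN; standardizing keeps the image columns from swamping, or being swamped by, the original features.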

Complete Proposed Model Code (PyTorch)

The PyTorch re-implementation below sketches the PEGN framework described in the paper: persistence image computation, the PI+ structural augmentation, the node-classification GNN with topology-weighted message passing, the link prediction branch with the Fermi-Dirac decoder, and the end-to-end learnable filter variant. Each component is keyed to the paper’s equations and Figure 3 architecture, but the EPD routine is a simplified approximation rather than the Dionysus-based computation the authors use, and the vicinity filter uses hop distance instead of curvature-weighted distance. A smoke test at the bottom exercises all modules on synthetic data without any external dataset.

# ==============================================================================
# PEGN: Persistence Enhanced Graph Network
# Paper: "Enhancing Graph Representation Learning with Localized Topological Features"
# Journal: JMLR 26 (2025) 1-36
# Authors: Zuoyu Yan, Qi Zhao, Ze Ye, Tengfei Ma, Liangcai Gao,
#          Zhi Tang, Yusu Wang, Chao Chen
# Code: PyTorch re-implementation (original paper used Python + Dionysus)
# ==============================================================================

from __future__ import annotations
import math, warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field

warnings.filterwarnings('ignore')

# ─── SECTION 1: Topology Utilities — EPD & Persistence Image ─────────────────

def compute_betti_from_adjacency(adj: np.ndarray, filter_vals: np.ndarray,
                                   n_bins: int = 5) -> List[Tuple[float, float]]:
    """
    Lightweight approximation of the 1D Extended Persistence Diagram (EPD).

    Ascending pass: each undirected edge enters once, in order of
    f(uv) = max(f(u), f(v)); an edge whose endpoints are already connected
    closes a loop born at f(uv).
    Descending pass (approximated): each loop dies at the minimum filter value
    on its fundamental cycle (the closing edge plus the spanning-forest path
    between its endpoints). This yields one point per independent loop, but it
    simplifies the exact extended persistence pairing of Section 3.2.

    Parameters
    ----------
    adj         : (n, n) symmetric adjacency matrix of the vicinity graph
    filter_vals : (n,) scalar filter function value per node
    n_bins      : unused; kept only for backward compatibility with callers

    Returns
    -------
    persistence_pts : list of (birth, death) tuples (1D EPD, death <= birth)
    """
    n = len(filter_vals)
    if n == 0:
        return []

    # Union-Find over nodes for the ascending pass
    parent = np.arange(n)

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    # Spanning-forest adjacency, used to recover each loop's fundamental cycle
    forest = {i: [] for i in range(n)}

    def forest_path(src, dst):
        # BFS in the spanning forest; returns the nodes on the src -> dst path
        prev = {src: None}
        queue = [src]
        while queue:
            cur = queue.pop(0)
            if cur == dst:
                break
            for nb in forest[cur]:
                if nb not in prev:
                    prev[nb] = cur
                    queue.append(nb)
        path, cur = [], dst
        while cur is not None:
            path.append(cur)
            cur = prev[cur]
        return path

    # Each undirected edge exactly once, sorted by its ascending filter value
    rows, cols = np.nonzero(np.triu(adj, k=1))
    edges = sorted(zip(rows.tolist(), cols.tolist()),
                   key=lambda e: max(filter_vals[e[0]], filter_vals[e[1]]))

    persistence_pts = []
    for i, j in edges:
        ri, rj = find(i), find(j)
        if ri == rj:
            # Edge closes a loop: birth from the ascending filtration
            birth = float(max(filter_vals[i], filter_vals[j]))
            # Approximate death from the descending filtration
            death = float(min(filter_vals[c] for c in forest_path(i, j)))
            persistence_pts.append((birth, death))
        else:
            parent[rj] = ri       # merge components
            forest[i].append(j)   # record the tree edge
            forest[j].append(i)

    return persistence_pts


def persistence_image(
    pts: List[Tuple[float, float]],
    resolution: int = 5,
    sigma: float = 1.0,
    pixel_min: float = 0.0,
    pixel_max: float = 1.0,
) -> np.ndarray:
    """
    Convert a 1D persistence diagram to a fixed-size persistence image.

    Implements the vectorisation from Adams et al. (2017) as described
    in Section 3.3 of the paper.

    Steps:
      1. Transform (b, d) → (b, |d-b|) so persistence = y-coordinate
         (1D extended points have d < b, hence the absolute value)
      2. Apply piecewise-linear weight α(x, y) to suppress noise
      3. Place 2D Gaussian φ_u(z) at each transformed point
      4. Evaluate the surface at the grid points (a cheap stand-in for
         integrating over each pixel) to get the fixed-size PI vector

    Parameters
    ----------
    pts        : list of (birth, death) persistence points
    resolution : grid size → output vector length = resolution²
    sigma      : Gaussian std dev
    pixel_min  : lower bound of pixel grid
    pixel_max  : upper bound of pixel grid

    Returns
    -------
    pi_flat : (resolution*resolution,) numpy array, the persistence image
    """
    grid = np.linspace(pixel_min, pixel_max, resolution)
    pi = np.zeros((resolution, resolution), dtype=np.float32)

    for (b, d) in pts:
        pers = abs(d - b)  # persistence = lifetime |b - d|
        # Piecewise-linear weight (Eq. 1 of paper): 0 on the diagonal, capped at 1
        weight = min(pers, 1.0)
        if weight == 0.0:
            continue

        for xi, gx in enumerate(grid):
            for yi, gy in enumerate(grid):
                gauss = math.exp(-((gx - b)**2 + (gy - pers)**2) / (2 * sigma**2))
                pi[xi, yi] += weight * gauss / (2 * math.pi * sigma**2)

    return pi.flatten()


def compute_pi_plus(
    adj: np.ndarray,
    filter_vals: np.ndarray,
    k: int = 2,
    resolution: int = 5,
) -> np.ndarray:
    """
    Compute Persistence Image Plus (PI+) as defined in Eq. 4 of the paper.

    PI+ = [PI ; {n_j | j ≤ k} ; {n_{j,j} | j ≤ k} ; {n_{j,j+1} | j < k}]

    Appends explicit layer-count statistics to the standard persistence image:
      - n_j     : number of nodes at hop-distance j from the root
      - n_jj    : number of edges among level-j nodes
      - n_j(j+1): number of crossing edges from level-j to level-(j+1)

    Parameters
    ----------
    adj         : (n, n) adjacency matrix of the vicinity graph
    filter_vals : (n,) raw hop distance from the root u (the root has value 0).
                  Pass unnormalised distances: the level counts round these
                  values to integers, so [0, 1]-normalised input breaks them.
    k           : number of hops considered
    resolution  : persistence image resolution

    Returns
    -------
    pi_plus : concatenated PI+ vector
    """
    # Integer hop levels for the counts; a [0, 1]-rescaled copy for the PI grid
    levels = np.round(filter_vals).astype(int)
    scaled = (filter_vals - filter_vals.min()) / (filter_vals.max() - filter_vals.min() + 1e-8)
    pi = persistence_image(
        compute_betti_from_adjacency(adj, scaled), resolution=resolution
    )

    nj_list, njj_list, njj1_list = [], [], []
    for j in range(k + 1):
        level_j = np.where(levels == j)[0]
        nj_list.append(float(len(level_j)))
        # Intra-layer edges (each pair counted once)
        intra = sum(adj[i, l] for i in level_j for l in level_j if i < l)
        njj_list.append(float(intra))
        if j < k:
            # Inter-layer edges between levels j and j+1
            level_j1 = np.where(levels == j + 1)[0]
            cross = sum(adj[i, l] for i in level_j for l in level_j1)
            njj1_list.append(float(cross))

    return np.concatenate([pi, nj_list, njj_list, njj1_list])


def build_vicinity_and_pi(
    adj_full: np.ndarray,
    node_idx: int,
    k: int = 2,
    resolution: int = 5,
    use_pi_plus: bool = False,
) -> np.ndarray:
    """
    Extract k-hop vicinity graph G^k_u for a node and compute its PI.

    The filter function is the hop distance from node_idx to every other node
    in the vicinity. This simplifies Section 4.2.1 of the paper, where edges
    are weighted by Ollivier-Ricci curvature (plus 1) before computing
    shortest-path distances.

    Parameters
    ----------
    adj_full  : (N, N) full graph adjacency matrix
    node_idx  : target node u
    k         : hop neighbourhood radius
    resolution: PI grid resolution
    use_pi_plus: whether to return PI+ (Eq. 4) instead of plain PI

    Returns
    -------
    pi_vec : persistence image (or PI+) vector for node u
    """
    N = adj_full.shape[0]
    # BFS to find k-hop neighbourhood
    dist = np.full(N, np.inf)
    dist[node_idx] = 0
    queue = [node_idx]
    while queue:
        cur = queue.pop(0)
        if dist[cur] >= k:
            continue
        for nb in np.where(adj_full[cur] > 0)[0]:
            if dist[nb] == np.inf:
                dist[nb] = dist[cur] + 1
                queue.append(nb)

    vicinity_nodes = np.where(dist <= k)[0]
    if len(vicinity_nodes) < 2:
        dim = resolution * resolution
        if use_pi_plus:
            dim += (k + 1) + (k + 1) + k
        return np.zeros(dim, dtype=np.float32)

    # Induced subgraph on the vicinity nodes
    adj_sub = (adj_full[np.ix_(vicinity_nodes, vicinity_nodes)] > 0).astype(np.float64)
    hop_sub = dist[vicinity_nodes].astype(np.float32)

    if use_pi_plus:
        # PI+ needs the raw hop levels for n_j, n_{j,j}, n_{j,j+1}
        return compute_pi_plus(adj_sub, hop_sub, k=k, resolution=resolution).astype(np.float32)

    filter_sub = (hop_sub - hop_sub.min()) / (hop_sub.max() - hop_sub.min() + 1e-8)
    pts = compute_betti_from_adjacency(adj_sub, filter_sub)
    return persistence_image(pts, resolution=resolution).astype(np.float32)


# ─── SECTION 2: GCN Layer ─────────────────────────────────────────────────────

class GCNLayer(nn.Module):
    """
    Standard Graph Convolutional Network layer (Kipf & Welling, 2017).

    Implements:  H' = σ(D̃^{-1/2} Ã D̃^{-1/2} H W)
    where à = A + I (self-loops added) and D̃ is the corresponding degree matrix.
    """
    def __init__(self, in_dim: int, out_dim: int, bias: bool = True):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=bias)
        nn.init.xavier_uniform_(self.W.weight)

    def forward(self, x: torch.Tensor, adj_norm: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x        : (N, in_dim) node feature matrix
        adj_norm : (N, N) symmetrically normalised adjacency à = D^{-1/2}(A+I)D^{-1/2}

        Returns
        -------
        out : (N, out_dim) updated node features
        """
        return F.elu(self.W(adj_norm @ x))


def normalise_adj(adj: torch.Tensor) -> torch.Tensor:
    """Symmetrically normalise adjacency with self-loops: D^{-1/2}(A+I)D^{-1/2}."""
    A_hat = adj + torch.eye(adj.shape[0], device=adj.device)
    D = A_hat.sum(dim=1)
    D_inv_sqrt = torch.diag(1.0 / (D.sqrt() + 1e-8))
    return D_inv_sqrt @ A_hat @ D_inv_sqrt


# ─── SECTION 3: PEGN for Node Classification ─────────────────────────────────

class TopologyEdgeWeight(nn.Module):
    """
    Computes topology-aware edge weights from persistence images.

    Implements Eq. 6 of the paper:
      α^{k-1}_{uv} = σ( W^{k-1}_2  [PI(u) ; PI(v)] )

    The concatenated PI pair is passed through a two-layer MLP that produces
    a single scalar per directed edge, which then reweights message passing.
    """
    def __init__(self, pi_dim: int, hidden: int = 50):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(pi_dim * 2, hidden),
            nn.ELU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, pi_u: torch.Tensor, pi_v: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        pi_u : (E, pi_dim) persistence images of the source node of each edge
        pi_v : (E, pi_dim) persistence images of the target node of each edge

        Returns
        -------
        weights : (E,) scalar edge weights in (0, 1)
        """
        cat = torch.cat([pi_u, pi_v], dim=-1)
        return torch.sigmoid(self.mlp(cat).squeeze(-1))
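

# ── Sketch (illustrative, not from the paper): Eq. 5 as an edge-list scatter
# PEGNNodeClassifier below realises Eq. 5 by writing the α_uv into a dense
# adjacency and then applying self-loops and symmetric normalisation. Read
# literally, Eq. 5 is an unnormalised weighted sum over neighbours,
# h_u = σ(Σ_v α_uv W_1 h_v). The sketch below computes that sum with
# index_add_ and no N×N matrix. The function name and argument layout are
# hypothetical, and ELU stands in for σ to match the rest of this listing.
import torch
import torch.nn.functional as F


def topology_weighted_aggregate(h_w: torch.Tensor, u_idx: torch.Tensor,
                                v_idx: torch.Tensor,
                                alpha: torch.Tensor) -> torch.Tensor:
    """
    h_w   : (N, d) neighbour features already multiplied by W_1
    u_idx : (E,) receiving node u of each edge
    v_idx : (E,) sending neighbour v of each edge
    alpha : (E,) topology edge weights α_uv, e.g. from TopologyEdgeWeight
    """
    # α_uv * (W_1 h_v) for every edge, summed into its receiving node u
    messages = alpha.unsqueeze(-1) * h_w[v_idx]
    out = torch.zeros_like(h_w).index_add_(0, u_idx, messages)
    return F.elu(out)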


class PEGNNodeClassifier(nn.Module):
    """
    Persistence Enhanced Graph Network for Node Classification (Fig. 3a).

    Architecture:
      - 2-layer GCN with topology-reweighted message passing
      - Edge weights computed from paired persistence images (Eq. 6)
      - Output: per-node class log-probabilities via log_softmax

    The topology-weighted aggregation follows Eq. 5:
      h^k_u = σ( Σ_{v ∈ N(u)} α^{k-1}_{uv} W^{k-1}_1 h^{k-1}_v )
    Here the α-weighted adjacency additionally receives self-loops and
    symmetric normalisation (normalise_adj) before aggregation.

    Parameters
    ----------
    in_dim    : input node feature dimension
    hid_dim   : hidden layer dimension (64 in the paper)
    out_dim   : number of output classes
    pi_dim    : dimension of each persistence image vector
    dropout   : dropout rate
    """
    def __init__(self, in_dim: int, hid_dim: int, out_dim: int,
                 pi_dim: int, dropout: float = 0.5):
        super().__init__()
        self.W1_0 = nn.Linear(in_dim, hid_dim)
        self.W1_1 = nn.Linear(hid_dim, out_dim)
        self.edge_weight_0 = TopologyEdgeWeight(pi_dim)
        self.edge_weight_1 = TopologyEdgeWeight(pi_dim)
        self.dropout = nn.Dropout(dropout)
        self.pi_dim = pi_dim

    def forward(
        self,
        x: torch.Tensor,
        adj: torch.Tensor,
        pi: torch.Tensor,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        x   : (N, in_dim) node features
        adj : (N, N) raw adjacency (will be normalised internally)
        pi  : (N, pi_dim) persistence images for all nodes

        Returns
        -------
        logits : (N, out_dim) class log-probabilities
        """
        adj_norm = normalise_adj(adj)  # fallback for a graph with no edges

        # ── Layer 0 with topology edge weights
        src_idx, dst_idx = adj.nonzero(as_tuple=True)
        if len(src_idx) > 0:
            w0 = self.edge_weight_0(pi[src_idx], pi[dst_idx])
            weighted_adj0 = torch.zeros_like(adj)
            weighted_adj0[src_idx, dst_idx] = w0
            adj_norm0 = normalise_adj(weighted_adj0)
        else:
            adj_norm0 = adj_norm

        h0 = F.elu(self.W1_0(adj_norm0 @ x))
        h0 = self.dropout(h0)

        # ── Layer 1 with topology edge weights
        if len(src_idx) > 0:
            w1 = self.edge_weight_1(pi[src_idx], pi[dst_idx])
            weighted_adj1 = torch.zeros_like(adj)
            weighted_adj1[src_idx, dst_idx] = w1
            adj_norm1 = normalise_adj(weighted_adj1)
        else:
            adj_norm1 = adj_norm

        logits = self.W1_1(adj_norm1 @ h0)
        return F.log_softmax(logits, dim=-1)
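

# ── Sketch (illustrative, not part of PEGN): do gradients reach the edge MLP?
# The classifier scatters learned weights into a fresh zero matrix with
# `weighted_adj0[src_idx, dst_idx] = w0`. Autograd records that indexed
# assignment, so the loss gradient flows back into the scattered values and
# from there into TopologyEdgeWeight. The toy check below (hypothetical name)
# shows this on a 2-node graph: out = [2 * w0, 1 * w1], so grad = [2, 1].
import torch


def scatter_grad_check() -> torch.Tensor:
    w = torch.tensor([0.3, 0.7], requires_grad=True)
    dense = torch.zeros(2, 2)
    dense[torch.tensor([0, 1]), torch.tensor([1, 0])] = w
    x = torch.tensor([[1.0], [2.0]])
    (dense @ x).sum().backward()
    return w.grad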


# ─── SECTION 4: PEGN for Link Prediction ─────────────────────────────────────

class FermiDiracDecoder(nn.Module):
    """
    Modified Fermi-Dirac decoder for link prediction (Section 4.2.2).

    Combines the embedding difference (h_u - h_v) with the pair-wise
    persistence image PI(u,v) via a 2-layer MLP, then maps the result
    to an edge probability through the Fermi-Dirac function:

        FD(u,v) = MLP( [h_u - h_v ; PI(u,v)] )
        prob(uv) = 1 / (e^{FD(u,v) - 2} + 1)

    Parameters
    ----------
    node_emb_dim : dimension of GCN output node embeddings
    pi_dim       : dimension of pair-wise persistence image PI(u,v)
    hidden       : hidden layer size of the MLP (default 25, matching paper)
    """
    def __init__(self, node_emb_dim: int, pi_dim: int, hidden: int = 25):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(node_emb_dim + pi_dim, hidden),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(hidden, 1),
        )

    def forward(
        self,
        h_u: torch.Tensor,
        h_v: torch.Tensor,
        pi_uv: torch.Tensor,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        h_u   : (B, node_emb_dim) embeddings of source nodes in batch
        h_v   : (B, node_emb_dim) embeddings of target nodes in batch
        pi_uv : (B, pi_dim) pair-wise persistence images

        Returns
        -------
        probs : (B,) edge existence probabilities in (0, 1)
        """
        diff = h_u - h_v
        combined = torch.cat([diff, pi_uv], dim=-1)
        fd_val = self.mlp(combined).squeeze(-1)
        # 1 / (e^{FD - 2} + 1) is exactly sigmoid(2 - FD)
        return torch.sigmoid(2.0 - fd_val)
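

# ── Sketch (illustrative): the Fermi-Dirac form as a shifted sigmoid
# The docstring above writes prob = 1 / (e^{FD - 2} + 1). That expression is
# sigmoid(2 - FD), so a larger MLP score FD means a lower link probability,
# and FD = 2 gives exactly 0.5. The sketch assumes the general radius r and
# temperature t found in other Fermi-Dirac decoders (r = 2, t = 1 reproduces
# the docstring); the name `fermi_dirac_prob` is hypothetical.
import torch


def fermi_dirac_prob(fd: torch.Tensor, r: float = 2.0, t: float = 1.0) -> torch.Tensor:
    """1 / (exp((fd - r) / t) + 1), computed stably as sigmoid((r - fd) / t)."""
    return torch.sigmoid((r - fd) / t)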


class PEGNLinkPredictor(nn.Module):
    """
    Persistence Enhanced Graph Network for Link Prediction (Fig. 3b).

    Two-branch architecture:
      Upper branch: 2-layer GCN → node embeddings H
      Lower branch: pair-wise vicinity graph → PI(u,v)
      Decoder:      Fermi-Dirac on [h_u - h_v ; PI(u,v)]

    Parameters
    ----------
    in_dim       : input node feature dimension
    hid_dim      : GCN hidden dimension (100 in paper)
    emb_dim      : GCN output embedding dimension (16 in paper)
    pi_dim       : dimension of pair-wise persistence image
    dropout      : dropout rate
    """
    def __init__(self, in_dim: int, hid_dim: int, emb_dim: int,
                 pi_dim: int, dropout: float = 0.5):
        super().__init__()
        self.gcn1 = GCNLayer(in_dim, hid_dim)
        self.gcn2 = GCNLayer(hid_dim, emb_dim)
        self.decoder = FermiDiracDecoder(emb_dim, pi_dim)
        self.dropout = nn.Dropout(dropout)

    def encode(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Upper branch: GCN node embedding encoder."""
        adj_norm = normalise_adj(adj)
        h = self.gcn1(x, adj_norm)
        h = self.dropout(h)
        h = self.gcn2(h, adj_norm)
        return h

    def forward(
        self,
        x: torch.Tensor,
        adj: torch.Tensor,
        edge_pairs: torch.Tensor,
        pi_pairs: torch.Tensor,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        x          : (N, in_dim) node features
        adj        : (N, N) adjacency matrix (positive train edges only)
        edge_pairs : (B, 2) tensor of (u, v) node index pairs
        pi_pairs   : (B, pi_dim) pair-wise persistence images PI(u,v)

        Returns
        -------
        probs : (B,) link probability for each pair
        """
        H = self.encode(x, adj)
        u_idx = edge_pairs[:, 0]
        v_idx = edge_pairs[:, 1]
        return self.decoder(H[u_idx], H[v_idx], pi_pairs)
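

# ── Sketch (illustrative, not from the paper): a pair-wise vicinity graph
# The lower branch needs a subgraph around each candidate pair (u, v) before
# PI(u, v) can be computed. This minimal sketch assumes the vicinity graph is
# the subgraph induced by the union of the k-hop neighbourhoods of u and v.
# That definition is an assumption here, not taken from this listing. Both
# helper names are hypothetical, and the input is a dense NumPy adjacency.
from collections import deque

import numpy as np


def vicinity_k_hop_set(adj: np.ndarray, source: int, k: int) -> set:
    """Nodes within k hops of `source` (breadth-first search)."""
    seen = {source}
    frontier = deque([(source, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth == k:
            continue
        for nb in np.flatnonzero(adj[node]).tolist():
            if nb not in seen:
                seen.add(nb)
                frontier.append((nb, depth + 1))
    return seen


def pair_vicinity_subgraph(adj: np.ndarray, u: int, v: int, k: int = 1):
    """Return (sorted node ids, induced adjacency) of the assumed vicinity graph."""
    nodes = sorted(vicinity_k_hop_set(adj, u, k) | vicinity_k_hop_set(adj, v, k))
    return nodes, adj[np.ix_(nodes, nodes)]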


# ─── SECTION 5: Learnable Filter Function (End-to-End, Section 4.3.2) ─────────

class LearnableFilterGNN(nn.Module):
    """
    End-to-end learnable filter function for PEGN (Section 4.3.2).

    Instead of using a fixed filter function (e.g. Ollivier-Ricci curvature),
    this module learns the filter values from the graph structure. Filter values
    are initialised as node degrees, then refined through GCN blocks. The output
    filter values can be used to compute persistence images dynamically.

    Note: In practice, only 0D ordinary PDs are computed (not 1D EPDs) to reduce
    computational cost, following the paper's implementation choice.

    Parameters
    ----------
    in_dim  : input node feature dimension
    hid_dim : GCN hidden dimension for filter learning
    """
    def __init__(self, in_dim: int, hid_dim: int = 32):
        super().__init__()
        self.gcn_filter1 = GCNLayer(in_dim, hid_dim)
        self.gcn_filter2 = GCNLayer(hid_dim, 1)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x   : (N, in_dim) node features (initialised with node degree)
        adj : (N, N) adjacency matrix

        Returns
        -------
        filter_vals : (N,) learned filter values, normalised to [0, 1]
        """
        adj_norm = normalise_adj(adj)
        h = self.gcn_filter1(x, adj_norm)
        raw = self.gcn_filter2(h, adj_norm).squeeze(-1)
        # Normalise to [0, 1] for stable persistence computation
        f_min, f_max = raw.min(), raw.max()
        return (raw - f_min) / (f_max - f_min + 1e-8)
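

# ── Sketch (illustrative, not the authors' Dionysus code): 0D persistence
# The note above says only 0D ordinary persistence diagrams are computed for
# the learned filter. For node filter values on a graph, one standard way is
# the lower-star filtration: each edge enters at the larger endpoint value.
# Union-find with the elder rule then yields the (birth, death) pairs.
# Whether PEGN uses sublevel or superlevel sets here is an assumption, and the
# name `lower_star_0d_pairs` is hypothetical.
import math
from typing import List, Sequence, Tuple


def lower_star_0d_pairs(filt: Sequence[float],
                        edges: Sequence[Tuple[int, int]]) -> List[Tuple[float, float]]:
    """0D persistence pairs (birth, death); essential classes die at inf."""
    n = len(filt)
    parent = list(range(n))
    birth = [float(f) for f in filt]
    nbrs = [[] for _ in range(n)]
    for a, b in edges:
        nbrs[a].append(b)
        nbrs[b].append(a)

    def find(i: int) -> int:
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    added = [False] * n
    pairs = []
    for v in sorted(range(n), key=lambda i: filt[i]):
        added[v] = True
        # Edges to already-added neighbours enter at f(v), the larger endpoint value
        for u in nbrs[v]:
            if not added[u]:
                continue
            ru, rv = find(u), find(v)
            if ru == rv:
                continue  # the edge closes a loop: a 1D feature, ignored in 0D
            # Elder rule: the younger component (later birth) dies at f(v)
            young, old = (ru, rv) if birth[ru] > birth[rv] else (rv, ru)
            if birth[young] < filt[v]:
                pairs.append((birth[young], float(filt[v])))
            parent[young] = old
    # Each surviving root is an essential class of one connected component
    pairs.extend((birth[r], math.inf) for r in {find(i) for i in range(n)})
    return sorted(pairs)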


# ─── SECTION 6: Training Utilities ───────────────────────────────────────────

def train_node_classifier(
    model: PEGNNodeClassifier,
    x: torch.Tensor,
    adj: torch.Tensor,
    pi: torch.Tensor,
    labels: torch.Tensor,
    train_mask: torch.Tensor,
    val_mask: torch.Tensor,
    epochs: int = 200,
    lr: float = 0.005,
    weight_decay: float = 5e-4,
    patience: int = 100,
    verbose: bool = True,
) -> Dict:
    """
    Training loop for PEGNNodeClassifier. Uses cross-entropy loss and Adam
    optimiser following the experimental setup in Section 5.1.

    Parameters
    ----------
    model       : fitted PEGNNodeClassifier instance
    x           : (N, F) node features
    adj         : (N, N) adjacency
    pi          : (N, pi_dim) precomputed persistence images
    labels      : (N,) integer class labels
    train_mask  : (N,) boolean training mask
    val_mask    : (N,) boolean validation mask
    epochs      : max training epochs
    lr          : Adam learning rate (0.005 in paper)
    weight_decay: L2 regularisation (5e-4 in paper)
    patience    : early stopping patience (100 in paper)

    Returns
    -------
    history : dict with 'train_loss', 'val_acc', 'best_epoch' lists
    """
    optimiser = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val_acc, patience_counter, best_epoch = 0.0, 0, 0
    history = {'train_loss': [], 'val_acc': []}

    for epoch in range(epochs):
        model.train()
        optimiser.zero_grad()
        logits = model(x, adj, pi)
        loss = F.nll_loss(logits[train_mask], labels[train_mask])
        loss.backward()
        optimiser.step()

        model.eval()
        with torch.no_grad():
            val_logits = model(x, adj, pi)
            val_preds = val_logits[val_mask].argmax(dim=-1)
            val_acc = (val_preds == labels[val_mask]).float().mean().item()

        history['train_loss'].append(loss.item())
        history['val_acc'].append(val_acc)

        if val_acc > best_val_acc:
            best_val_acc, best_epoch, patience_counter = val_acc, epoch, 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                if verbose:
                    print(f"  Early stop at epoch {epoch} (best val={best_val_acc:.4f} @ {best_epoch})")
                break

        if verbose and epoch % 50 == 0:
            print(f"  Epoch {epoch:>4d} | Loss {loss.item():.4f} | Val Acc {val_acc:.4f}")

    history['best_epoch'] = best_epoch
    return history


def train_link_predictor(
    model: PEGNLinkPredictor,
    x: torch.Tensor,
    adj_train: torch.Tensor,
    pos_pairs: torch.Tensor,
    neg_pairs: torch.Tensor,
    pi_pos: torch.Tensor,
    pi_neg: torch.Tensor,
    epochs: int = 200,
    lr: float = 0.01,
    verbose: bool = True,
) -> List[float]:
    """
    Training loop for PEGNLinkPredictor.

    Uses binary cross-entropy with negative sampling, following Section 4.2.2.
    Positive pairs are existing edges (label 1); negative pairs are non-edges
    (label 0) supplied by the caller and kept fixed across epochs. The model is
    trained to score positive pairs high and negative pairs low.

    Parameters
    ----------
    model      : PEGNLinkPredictor instance
    x          : (N, F) node features
    adj_train  : (N, N) training adjacency (test edges removed)
    pos_pairs  : (P, 2) positive edge pairs (u, v)
    neg_pairs  : (Q, 2) negative edge pairs (u, v)
    pi_pos     : (P, pi_dim) PI(u,v) for positive pairs
    pi_neg     : (Q, pi_dim) PI(u,v) for negative pairs
    epochs     : training epochs
    lr         : Adam learning rate (0.01 in paper)

    Returns
    -------
    loss_history : list of training loss values per epoch
    """
    optimiser = Adam(model.parameters(), lr=lr)
    loss_history = []
    all_pairs = torch.cat([pos_pairs, neg_pairs], dim=0)
    all_pi   = torch.cat([pi_pos,   pi_neg],   dim=0)
    labels = torch.cat([
        torch.ones(len(pos_pairs)),
        torch.zeros(len(neg_pairs))
    ]).to(x.device)  # keep targets on the same device as the predictions

    for epoch in range(epochs):
        model.train()
        optimiser.zero_grad()
        probs = model(x, adj_train, all_pairs, all_pi)
        loss = F.binary_cross_entropy(probs, labels)
        loss.backward()
        optimiser.step()
        loss_history.append(loss.item())
        if verbose and epoch % 50 == 0:
            print(f"  Epoch {epoch:>4d} | Loss {loss.item():.4f}")

    return loss_history
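

# ── Sketch (illustrative, not from the paper): sampling negative pairs
# The loop above reuses whatever fixed negative set the caller passes in. The
# helper below is a minimal sketch of uniform sampling without replacement
# from all unordered non-adjacent pairs. A caller could run it once per epoch
# to resample, but PI(u, v) for the new negatives would then have to be
# recomputed as well. The name `sample_negative_pairs` is hypothetical, and
# building the candidate set costs O(N^2) memory.
from typing import Optional

import torch


def sample_negative_pairs(adj: torch.Tensor, num: int,
                          generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Uniformly sample `num` distinct pairs (u, v), u < v, with adj[u, v] == 0."""
    n = adj.shape[0]
    idx = torch.arange(n, device=adj.device)
    upper = idx[:, None] < idx[None, :]          # strict upper triangle, no self-pairs
    candidates = (upper & (adj == 0)).nonzero()  # (K, 2) non-edge pairs
    if num > candidates.shape[0]:
        raise ValueError(f"requested {num} negatives but only {candidates.shape[0]} non-edges exist")
    # randperm runs on the CPU so a CPU generator can be passed, then moves to adj's device
    perm = torch.randperm(candidates.shape[0], generator=generator)[:num]
    return candidates[perm.to(adj.device)]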


def evaluate_node_classifier(
    model: PEGNNodeClassifier,
    x: torch.Tensor,
    adj: torch.Tensor,
    pi: torch.Tensor,
    labels: torch.Tensor,
    test_mask: torch.Tensor,
) -> float:
    """Return test accuracy for node classification."""
    model.eval()
    with torch.no_grad():
        logits = model(x, adj, pi)
        preds = logits[test_mask].argmax(dim=-1)
        return (preds == labels[test_mask]).float().mean().item()


def evaluate_link_predictor(
    model: PEGNLinkPredictor,
    x: torch.Tensor,
    adj: torch.Tensor,
    pos_pairs: torch.Tensor,
    neg_pairs: torch.Tensor,
    pi_pos: torch.Tensor,
    pi_neg: torch.Tensor,
) -> float:
    """Return AUC-ROC approximation for link prediction."""
    from sklearn.metrics import roc_auc_score
    model.eval()
    with torch.no_grad():
        all_pairs = torch.cat([pos_pairs, neg_pairs], dim=0)
        all_pi    = torch.cat([pi_pos, pi_neg], dim=0)
        probs = model(x, adj, all_pairs, all_pi).cpu().numpy()
        labels = np.concatenate([np.ones(len(pos_pairs)), np.zeros(len(neg_pairs))])
        return float(roc_auc_score(labels, probs))
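

# ── Sketch (illustrative): what roc_auc_score measures here
# AUC-ROC equals the probability that a randomly chosen positive pair scores
# higher than a randomly chosen negative pair, with ties counted as 1/2. The
# dependency-free sketch below computes that definition directly. It compares
# every positive with every negative (O(P·Q)), so it only suits small sanity
# checks. The name `pairwise_auc` is hypothetical.
import numpy as np


def pairwise_auc(pos_scores, neg_scores) -> float:
    pos = np.asarray(pos_scores, dtype=float)[:, None]
    neg = np.asarray(neg_scores, dtype=float)[None, :]
    wins = (pos > neg).sum() + 0.5 * (pos == neg).sum()
    return float(wins / (pos.size * neg.size))

# Example: positives [0.9, 0.4] and negatives [0.5, 0.1] win 3 of 4 comparisons.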


# ─── SECTION 7: Smoke Test ────────────────────────────────────────────────────

if __name__ == '__main__':
    print("=" * 62)
    print("PEGN Smoke Test — Persistent Homology Graph Network")
    print("=" * 62)
    torch.manual_seed(42)
    np.random.seed(42)

    # ── Synthetic graph: 30 nodes, random edges, 3-class labels
    N, F_DIM, N_CLASSES = 30, 16, 3
    adj_np = (np.random.rand(N, N) > 0.75).astype(float)
    adj_np = ((adj_np + adj_np.T) > 0).astype(float)  # symmetrize
    np.fill_diagonal(adj_np, 0)

    # ── Precompute persistence images for all nodes
    RESOLUTION = 5
    PI_DIM = RESOLUTION * RESOLUTION  # = 25
    print(f"Computing PI for {N} nodes (resolution={RESOLUTION})...")
    pi_np = np.stack([
        build_vicinity_and_pi(adj_np, i, k=2, resolution=RESOLUTION)
        for i in range(N)
    ])

    # ── Tensors
    x      = torch.randn(N, F_DIM)
    adj    = torch.tensor(adj_np, dtype=torch.float32)
    pi     = torch.tensor(pi_np, dtype=torch.float32)
    labels = torch.randint(0, N_CLASSES, (N,))

    idx = torch.randperm(N)
    train_mask = torch.zeros(N, dtype=torch.bool); train_mask[idx[:18]] = True
    val_mask   = torch.zeros(N, dtype=torch.bool); val_mask[idx[18:24]]  = True
    test_mask  = torch.zeros(N, dtype=torch.bool); test_mask[idx[24:]]   = True

    # ── Node Classification Test
    print("\n[1/3] Node Classification")
    nc_model = PEGNNodeClassifier(F_DIM, hid_dim=64, out_dim=N_CLASSES,
                                   pi_dim=PI_DIM, dropout=0.5)
    hist = train_node_classifier(nc_model, x, adj, pi, labels, train_mask, val_mask,
                                  epochs=150, patience=50, verbose=True)
    test_acc = evaluate_node_classifier(nc_model, x, adj, pi, labels, test_mask)
    print(f"  Test Accuracy: {test_acc:.4f}")

    # ── PI+ Test
    print("\n[2/3] PI+ Computation")
    pi_plus_np = np.stack([
        build_vicinity_and_pi(adj_np, i, k=2, resolution=RESOLUTION, use_pi_plus=True)
        for i in range(N)
    ])
    print(f"  PI shape: {pi_np.shape}, PI+ shape: {pi_plus_np.shape}")

    # ── Link Prediction Test
    print("\n[3/4] Link Prediction")
    edges = np.argwhere(adj_np > 0)
    edges = edges[edges[:, 0] < edges[:, 1]]  # upper triangle only
    np.random.shuffle(edges)
    split = len(edges) * 8 // 10
    train_edges, test_pos = edges[:split], edges[split:]

    adj_train_np = np.zeros_like(adj_np)
    for u, v in train_edges:
        adj_train_np[u, v] = adj_train_np[v, u] = 1
    adj_train = torch.tensor(adj_train_np, dtype=torch.float32)

    # Disjoint random negatives (non-edges of the full graph) for training and testing
    non_edges = np.argwhere(np.triu(adj_np == 0, k=1))
    np.random.shuffle(non_edges)
    train_neg = non_edges[:len(train_edges)]
    test_neg = non_edges[len(train_edges):len(train_edges) + len(test_pos)]

    # Node PIs recomputed on the training graph so held-out edges do not leak in
    pi_train = torch.tensor(np.stack([
        build_vicinity_and_pi(adj_train_np, i, k=2, resolution=RESOLUTION)
        for i in range(N)
    ]), dtype=torch.float32)

    def to_pairs(arr):
        return torch.tensor(arr, dtype=torch.long).reshape(-1, 2)

    def pair_pi(pairs):
        # Approximate PI(u,v) as the mean of the two node PIs (lightweight smoke test)
        return (pi_train[pairs[:, 0]] + pi_train[pairs[:, 1]]) / 2

    tr_pos, tr_neg = to_pairs(train_edges), to_pairs(train_neg)
    te_pos, te_neg = to_pairs(test_pos), to_pairs(test_neg)

    lp_model = PEGNLinkPredictor(F_DIM, hid_dim=100, emb_dim=16,
                                  pi_dim=PI_DIM, dropout=0.5)
    lp_hist = train_link_predictor(
        lp_model, x, adj_train,
        tr_pos, tr_neg, pair_pi(tr_pos), pair_pi(tr_neg),
        epochs=100, verbose=True
    )
    # Evaluate only on held-out pairs that were never used for training
    auc = evaluate_link_predictor(lp_model, x, adj_train,
                                   te_pos, te_neg, pair_pi(te_pos), pair_pi(te_neg))
    print(f"  Link Prediction AUC-ROC (held-out pairs): {auc:.4f}")

    # ── Learnable filter test
    print("\n[4/4] Learnable Filter Function")
    deg_feat = adj.sum(dim=-1, keepdim=True)
    ff_model = LearnableFilterGNN(in_dim=1, hid_dim=32)
    learned_filt = ff_model(deg_feat, adj)
    print(f"  Learned filter shape: {learned_filt.shape}, range [{learned_filt.min():.3f}, {learned_filt.max():.3f}]")

    print("\n✓  All PEGN smoke tests passed.")

Read the Full Paper & Explore the Code

The complete study, including all four theorems, extended benchmark tables for PPI, and detailed ablations of the end-to-end learning variant, is published open-access in JMLR under CC BY 4.0. The official implementation is on GitHub.

Academic Citation:
Yan, Z., Zhao, Q., Ye, Z., Ma, T., Gao, L., Tang, Z., Wang, Y., & Chen, C. (2025). Enhancing Graph Representation Learning with Localized Topological Features. Journal of Machine Learning Research, 26, 1–36. http://jmlr.org/papers/v26/23-1424.html

This article is an independent editorial analysis of peer-reviewed research. The PyTorch implementation is a simplified educational reproduction of the paper's framework, not the authors' code. The original authors used Python with the Dionysus library for EPD computation; refer to their official GitHub repository for the exact production implementation.
