Why Machine Learning on Curved Surfaces Is the Next Big Leap — And the Math That Makes It Work
Two researchers from the University of Minnesota and Rice University have cracked open a new frontier: bilevel optimization on Riemannian manifolds. Their algorithms, RieBO and RieSBO, match the oracle complexity of the best-known flat-space (Euclidean) methods, bringing convergence guarantees to meta-learning, robust estimation, and other constrained bilevel problems on curved geometry.
Most machine learning models live in flat, Euclidean space: the familiar land of vectors and matrices where straight lines stay straight. But a growing class of real-world problems, from learning covariance matrices to training on the sphere to enforcing orthogonality constraints, forces models to operate on curved surfaces called Riemannian manifolds. Until recently, bilevel optimization, the mathematical engine behind meta-learning and hyperparameter tuning, had barely been studied rigorously on these curved spaces. This paper changes that.
The World Isn’t Flat — And Neither Are Many ML Problems
The usual machine learning recipe is familiar: pick a loss function, run gradient descent, and eventually something converges. But what if the variables in your problem are not allowed to wander freely through Euclidean space? What if they are constrained to live on a surface such as a sphere, a set of orthogonal matrices, or a manifold of positive definite matrices?
These constraints show up constantly in real applications. Orthogonality constraints appear in dimensionality reduction and dictionary learning. Positive definite matrix constraints arise in covariance estimation and brain connectivity analysis. Subspace constraints appear in representation learning. Standard gradient descent struggles on these surfaces because a step in a flat direction generally takes you off the manifold, and projecting back after every step is an ad hoc fix whose errors are hard to control in a convergence analysis.
Riemannian optimization handles this by treating the manifold itself as the search space. Instead of taking steps along straight lines and projecting back, it follows geodesics, the curved “straight lines” on the manifold surface, like the great-circle routes that airplanes fly on a globe. Gradients become Riemannian gradients, which live in the tangent space at the current point; for a manifold embedded in Euclidean space with the induced metric, the Riemannian gradient is simply the projection of the Euclidean gradient onto that tangent space. This approach is mathematically clean and computationally tractable, and the optimization community has developed a rich theory for it over the past decade.
But there was a glaring gap. Bilevel optimization, where you minimize a function that itself depends on the solution of a separate inner optimization problem, had received almost no rigorous, convergence-guaranteed treatment on Riemannian manifolds. Problems like meta-learning with orthogonality constraints, or robust estimation where the model lives on a manifold, fell into this gap. You could run Euclidean bilevel algorithms and project the result back onto the manifold, but that approach has no theoretical guarantees and often performs poorly in practice. This paper fills that gap.
If you have a machine learning problem where some variables must stay on a curved surface — orthogonal matrices, positive definite matrices, the sphere, Grassmannian subspaces — and that problem has a nested structure (outer optimization depending on the solution of an inner optimization), this paper gives you the first principled algorithm with guaranteed convergence rates that match the best Euclidean methods.
What Bilevel Optimization Looks Like on a Manifold
The standard bilevel optimization problem asks you to minimize an outer function f(x, y*(x)), where y*(x) is the solution of a separate inner minimization problem over y. In the classic Euclidean setting, both x and y live in flat vector spaces. The paper’s key extension is to allow both x and y to live on Riemannian manifolds — call them M and N respectively.
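Formally, the deterministic Riemannian bilevel problem looks like this:
$$\min_{x \in M} \; \Phi(x) := f\big(x, y^*(x)\big) \quad \text{s.t.} \quad y^*(x) = \operatorname*{arg\,min}_{y \in N} \; g(x, y)$$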
Here M and N are complete Riemannian manifolds, and g is required to be geodesically strongly convex in y — meaning that along the curved “straight lines” (geodesics) of the manifold N, the function behaves like a strongly convex function in the usual sense. This guarantees a unique lower-level solution y*(x) for every upper-level variable x.
The stochastic version replaces the exact functions f and g with expectations over random data: f(x, y) = E_ξ[F(x, y; ξ)] and g(x, y) = E_ζ[G(x, y; ζ)]. This is the setting that applies when you train on mini-batches, which covers virtually every practical deep learning scenario.
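The central mathematical challenge is computing the gradient of Φ(x), called the hypergradient, because Φ depends on x both directly and indirectly through y*(x). On Euclidean spaces, this involves a chain rule and a Hessian inverse. On Riemannian manifolds, the chain rule still applies, but gradients live in tangent spaces that change from point to point, and the notion of “cross-derivative” needs to be generalized. The paper introduces the concept of the Riemannian cross-derivative to handle exactly this; writing y* = y*(x), the implicit function theorem gives
$$\operatorname{grad}\Phi(x) = \operatorname{grad}_x f(x, y^*) - \big(\operatorname{grad}^2_{y,x}\, g(x, y^*)\big)^{*}\Big[\mathrm{H}_y(g)(x, y^*)^{-1}\big[\operatorname{grad}_y f(x, y^*)\big]\Big]$$
where (·)* denotes the adjoint of a linear map and H_y(g) is the Riemannian Hessian of g with respect to y.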
The term grad²_{y,x}g is the Riemannian cross-derivative, a linear map from the tangent space of M to the tangent space of N, and computing it requires differentiating a vector field on the manifold, which is a non-trivial operation. The paper provides a complete framework for doing this, and approximates the Hessian-inverse-vector product H_y(g)⁻¹[grad_y f] numerically with the conjugate gradient method on the tangent space.
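In the flat special case M = N = R^d the formula reduces to the familiar AID hypergradient, which makes it easy to sanity-check. The sketch below uses a toy quadratic chosen for illustration (not an example from the paper): it solves H_y(g)[v] = grad_y f, subtracts the adjoint cross-derivative applied to v from grad_x f, and compares the result with central finite differences of Φ.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
R = rng.standard_normal((d, d))
H = R @ R.T + d * np.eye(d)   # H_y(g): SPD, so g is strongly convex in y
B = rng.standard_normal((d, d))
c = rng.standard_normal(d)

# Hypothetical toy instance:
#   g(x, y) = 0.5 y^T H y - y^T B x   =>  grad_y g = H y - B x,  y*(x) = H^{-1} B x
#   f(x, y) = 0.5 ||y - c||^2 + 0.5 ||x||^2


def y_star(x):
    return np.linalg.solve(H, B @ x)


def phi(x):
    y = y_star(x)
    return 0.5 * np.sum((y - c) ** 2) + 0.5 * np.sum(x ** 2)


def hypergradient(x):
    y = y_star(x)
    grad_x_f, grad_y_f = x, y - c
    v = np.linalg.solve(H, grad_y_f)  # solve H_y(g)[v] = grad_y f
    cross = -B                        # d/dx of grad_y g: maps x-space into y-space
    return grad_x_f - cross.T @ v     # grad_x f minus the adjoint cross-derivative applied to v


x0 = rng.standard_normal(d)
eps = 1e-6
fd = np.array([(phi(x0 + eps * e) - phi(x0 - eps * e)) / (2 * eps) for e in np.eye(d)])
err = np.max(np.abs(fd - hypergradient(x0)))
print(err)  # tiny: only finite-difference rounding error remains
```

On a curved manifold the same structure holds, but each piece (gradients, Hessian, cross-derivative) must be computed in the appropriate tangent space.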
Geodesics, Tangent Spaces, and Why Parallel Transport Matters
Before diving into the algorithms, it is worth pausing on the geometry. A Riemannian manifold is a smooth surface equipped with a way to measure lengths and angles at every point, which mathematicians call a Riemannian metric. The metric lets you define the shortest path between two points, called a geodesic. On a sphere, geodesics are great circles. On the manifold of positive definite matrices with the affine-invariant metric, the geodesic leaving S in direction V is the matrix exponential curve t ↦ S^{1/2} exp(t S^{-1/2} V S^{-1/2}) S^{1/2}.
The complication for optimization is that gradients at different points live in different tangent spaces. The gradient of a function at point x lives in the tangent space T_xM, while the gradient at a nearby point x’ lives in T_{x’}M. These are different vector spaces, so you cannot directly subtract or compare them. To fix this, the paper uses parallel transport: a way to move a tangent vector from one tangent space to another along a geodesic while preserving inner products, and hence lengths and angles.
Parallel transport plays a crucial role in the Lipschitz smoothness condition: the gradient of Φ at x and at x’ can only be compared after transporting one to the other’s tangent space, which gives a condition of the form ‖P_{x’→x}(grad Φ(x’)) − grad Φ(x)‖ ≤ L_Φ · d(x, x’), where P_{x’→x} denotes parallel transport and d is the geodesic distance. This subtlety propagates through every convergence proof in the paper and is one of the key reasons why extending Euclidean bilevel optimization to the Riemannian setting requires genuinely new mathematics, not just notational substitution.
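For the unit sphere, parallel transport along the minimizing great circle from x to y has a closed form, P(v) = v − (⟨y, v⟩ / (1 + ⟨x, y⟩))(x + y), valid when x and y are not antipodal. The minimal sketch below (our illustration; the function name is hypothetical) checks the two properties the text relies on: the transported vector is tangent at the new point, and inner products are preserved.

```python
import numpy as np


def sphere_transport(x: np.ndarray, y: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Parallel transport of v in T_x S^{d-1} to T_y S^{d-1} along the minimizing
    great circle (x and y unit vectors, not antipodal)."""
    return v - (np.dot(y, v) / (1.0 + np.dot(x, y))) * (x + y)


rng = np.random.default_rng(1)
x = rng.standard_normal(3)
x /= np.linalg.norm(x)
y = rng.standard_normal(3)
y /= np.linalg.norm(y)
v = rng.standard_normal(3)
v -= np.dot(x, v) * x  # make v tangent at x
w = rng.standard_normal(3)
w -= np.dot(x, w) * x  # make w tangent at x

pv, pw = sphere_transport(x, y, v), sphere_transport(x, y, w)
print(np.dot(y, pv))                           # ~0: pv lies in T_y
print(np.linalg.norm(pv) - np.linalg.norm(v))  # ~0: length preserved
print(np.dot(pv, pw) - np.dot(v, w))           # ~0: inner products (angles) preserved
```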
“The combination of bilevel optimization with Riemannian optimization was largely blank prior to this work. Our analysis on the stochastic Riemannian bilevel problem achieves a batch-free result — a batch size of O(1) — significantly smaller than the O(ε⁻¹) required by prior concurrent work.” — Jiaxiang Li & Shiqian Ma · JMLR 26 (2025)
The RieBO Algorithm — Deterministic Bilevel on Manifolds
The deterministic algorithm, RieBO (Riemannian Bilevel Optimization), follows the same high-level structure as the Euclidean AID-BiO algorithm it generalizes, but every step is adapted to the manifold geometry. Here is how it works at each outer iteration k:
- Initialize the inner loop using the output of the previous inner loop as the warm start: y^{k,0} = y^{k-1,T}.
- Run T steps of Riemannian gradient descent on the lower-level objective g(x^k, ·), updating via the exponential map: y^{k,t+1} = Exp_{y^{k,t}}(−β · grad_y g(x^k, y^{k,t})).
- Estimate the hypergradient using the approximate implicit differentiation (AID) formula on the manifold, solving the linear system H_y(g)[v] = grad_y f using N steps of the conjugate gradient method in the tangent space, with the previous conjugate gradient output (transported by parallel transport) as the warm start.
- Update the outer variable via the exponential map: x^{k+1} = Exp_{x^k}(−α · h^k_Φ), where h^k_Φ is the hypergradient estimated in the previous step.
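The four steps above can be sketched on a toy problem small enough to read in one screen. This is our own illustration, not the paper's code: x lives on the unit sphere, y lives in N = R^d (flat, hence a Hadamard manifold whose exponential map is addition and whose parallel transport is the identity), and the quadratic g and f, the step sizes, and all names are assumptions chosen for clarity.

```python
import numpy as np


def proj(x, g):
    """Tangent projection on the unit sphere M."""
    return g - np.dot(x, g) * x


def exp_sphere(x, v):
    """Great-circle exponential map on the unit sphere."""
    n = np.linalg.norm(v)
    return x if n < 1e-12 else np.cos(n) * x + np.sin(n) * v / n


def cg(hvp, b, v, iters):
    """`iters` steps of conjugate gradient for hvp(v) = b, warm-started at v."""
    r = b - hvp(v)
    p, rs = r.copy(), r @ r
    for _ in range(iters):
        if rs < 1e-30:
            break
        Ap = hvp(p)
        a = rs / (p @ Ap)
        v, r = v + a * p, r - a * Ap
        rs_new = r @ r
        p, rs = r + (rs_new / rs) * p, rs_new
    return v


# Hypothetical toy problem: g(x, y) = 0.5 y^T H y - y^T B x,  f(x, y) = 0.5 ||y - c||^2
rng = np.random.default_rng(0)
d = 5
R = rng.standard_normal((d, d))
H = R @ R.T / d + np.eye(d)
B, c = rng.standard_normal((d, d)), rng.standard_normal(d)


def phi(x):
    return 0.5 * np.sum((np.linalg.solve(H, B @ x) - c) ** 2)


alpha, beta = 0.02, 1.0 / np.linalg.eigvalsh(H).max()
T, N, K = 10, 3, 500
x = rng.standard_normal(d)
x /= np.linalg.norm(x)
y, v = np.zeros(d), np.zeros(d)
phi0 = phi(x)
for _ in range(K):
    for _ in range(T):                    # inner Riemannian GD, warm-started at the previous y
        y = y - beta * (H @ y - B @ x)    # Exp on R^d is plain addition
    v = cg(lambda u: H @ u, y - c, v, N)  # solve H_y(g)[v] = grad_y f, warm-started (identity transport)
    hyper = proj(x, B.T @ v)              # grad_x f = 0 and the cross-derivative is -B
    x = exp_sphere(x, -alpha * hyper)     # outer exponential-map step
print(phi0, phi(x), np.linalg.norm(x))
```

Because both the inner iterate y and the conjugate gradient vector v are carried over between outer iterations, a handful of inner and CG steps per iteration is enough for the outer loop to make steady progress; this is the intuition behind the warm starts in RieBO.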
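Under standard geodesic smoothness and strong convexity assumptions, with inner loop length T = O(κ) and conjugate gradient steps N = O(√κ), where κ is the condition number of the lower-level problem (its smoothness constant divided by its strong convexity constant), the RieBO algorithm finds an ε-stationary point of the Riemannian bilevel problem, that is, a point where ‖grad Φ(x)‖² ≤ ε, with the following oracle complexities:
Gc(f, ε) = O(κ³ε⁻¹) (gradient evaluations of f)
Gc(g, ε) = O(κ⁴ε⁻¹) (gradient evaluations of g)
JV(g, ε) = O(κ³ε⁻¹) (Jacobian-vector products with the cross-derivative of g)
HV(g, ε) = Õ(κ³·⁵ε⁻¹) (Hessian-vector products of g)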
These match the best-known complexities for the Euclidean counterpart (AID-BiO by Ji et al., 2021), confirming that moving to Riemannian manifolds does not cost extra in asymptotic complexity.
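A quick consistency check of these numbers, using our own back-of-the-envelope accounting rather than the paper's proof: assume each outer iteration uses one gradient of f, one cross-derivative product, T inner gradient steps, and N conjugate gradient steps (one Hessian-vector product each), so that the Gc(f, ε) bound fixes the number of outer iterations K.

```latex
\begin{aligned}
K &= O(\kappa^{3}\varepsilon^{-1}) \quad \text{(outer iterations)}\\
\mathrm{Gc}(f,\varepsilon) &= \mathrm{JV}(g,\varepsilon) = O(K) = O(\kappa^{3}\varepsilon^{-1})\\
\mathrm{Gc}(g,\varepsilon) &= K \cdot T = O(\kappa^{3}\varepsilon^{-1}) \cdot O(\kappa) = O(\kappa^{4}\varepsilon^{-1})\\
\mathrm{HV}(g,\varepsilon) &= K \cdot N = O(\kappa^{3}\varepsilon^{-1}) \cdot O(\sqrt{\kappa}) = O(\kappa^{3.5}\varepsilon^{-1})
\end{aligned}
```

The tilde in Õ(κ³·⁵ε⁻¹) hides logarithmic factors that this rough count does not track.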
The convergence proof relies on two key structural results. First, the Lipschitz smoothness of the solution map y*(x) and of the hypergradient grad Φ(x) must be established from the manifold geometry — this requires the parallel transport Lipschitz conditions and the Hadamard manifold assumption on N, which ensures geodesic strong convexity is well-defined. Second, the warm-start strategy for both the inner loop and the conjugate gradient solve must be shown to keep the approximation error bounded across iterations — this is where the telescoping sum argument and the careful choice of T and N come in.
The RieSBO Algorithm — Stochastic Bilevel on Manifolds
The stochastic algorithm, RieSBO, generalizes RieBO to the setting where you can only access noisy gradient estimates through random mini-batches. The key difference is in how the hypergradient is estimated. Instead of conjugate gradient, RieSBO uses a Neumann series approximation: a truncated power-series approximation to the inverse Hessian that can be computed using only stochastic Hessian-vector products.
The Neumann series estimator for the solution of H_y(g)[v] = grad_y f rests on the identity H⁻¹ = η Σ_{q≥0} (I − ηH)^q, which converges when 0 < η ≤ 1/L, where L bounds the eigenvalues of H, so that every eigenvalue of I − ηH lies in [0, 1). The randomized version draws Q′ uniformly from {0, …, Q − 1}, applies Q′ factors (I − ηH_q), each built from a fresh stochastic Hessian sample, and rescales the result by Qη. In expectation this equals the truncated series η Σ_{q=0}^{Q−1} (I − ηH)^q applied to grad_y f, whose bias relative to the exact inverse shrinks geometrically in Q. This is the Euclidean ALSET estimator of Chen et al. (2021), adapted to the Riemannian setting.
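The randomized estimator can be checked numerically in a few lines. The sketch below is a Euclidean illustration with a made-up diagonal Hessian and additive symmetric noise (both assumptions for the demo); it confirms that the Monte Carlo average of the randomized estimator matches the truncated series, and that the truncated series is already close to H⁻¹b.

```python
import numpy as np


def truncated_neumann(H, b, eta, Q):
    """Deterministic truncated series eta * sum_{q=0}^{Q-1} (I - eta*H)^q b."""
    out, term = np.zeros_like(b), b.copy()
    for _ in range(Q):
        out += term
        term = term - eta * (H @ term)
    return eta * out


def randomized_neumann(stoch_hvp, b, eta, Q, rng):
    """One draw of Q*eta * prod_{q=1}^{Q'} (I - eta*H_q) b with Q' uniform on {0, ..., Q-1}."""
    v = b.copy()
    for _ in range(rng.integers(0, Q)):
        v = v - eta * stoch_hvp(v)  # each call uses a fresh, independent Hessian sample
    return Q * eta * v


rng = np.random.default_rng(0)
H = np.diag([1.0, 2.0, 4.0])  # mean Hessian: mu = 1, L = 4
b = np.array([1.0, -1.0, 2.0])
eta, Q = 0.25, 40             # eta <= 1/L


def stoch_hvp(v):
    """Unbiased noisy Hessian-vector product: (H + symmetric zero-mean noise) v."""
    E = rng.normal(scale=0.1, size=(3, 3))
    return (H + (E + E.T) / 2) @ v


exact = np.linalg.solve(H, b)
trunc = truncated_neumann(H, b, eta, Q)
mc = np.mean([randomized_neumann(stoch_hvp, b, eta, Q, rng) for _ in range(10000)], axis=0)
print(np.abs(trunc - exact).max())  # truncation bias, shrinking like (1 - eta*mu)^Q
print(np.abs(mc - trunc).max())     # Monte Carlo error of the randomized estimator
```

Each draw touches only Q′ single-sample Hessian-vector products, which is why this estimator is compatible with O(1) batch sizes.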
Under stochastic smoothness and bounded variance assumptions, with stepsizes α = O(κ⁻²·⁵ K⁻⁰·⁵) and β = O(κ⁻¹·⁷⁵ K⁻⁰·⁵), and T = O(κ⁴) inner steps, the RieSBO algorithm satisfies:
(1/K) Σ E[‖grad Φ(xᵏ)‖²] ≤ O(κ²·⁵ / √K)
This gives oracle complexities Gc(F, ε) = O(κ⁵ε⁻²) and Gc(G, ε) = O(κ⁹ε⁻²), matching the Euclidean ALSET algorithm of Chen et al. (2021). Crucially, the batch size for the hypergradient estimate is O(1) — not O(ε⁻¹) as required by concurrent work.
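The same kind of accounting (again our own rough check, assuming a constant number of upper-level samples per outer iteration and T = O(κ⁴) lower-level gradient steps) shows how these complexities follow from the rate above:

```latex
\begin{aligned}
O\big(\kappa^{2.5}/\sqrt{K}\big) \le \varepsilon \;&\Longleftarrow\; K = O(\kappa^{5}\varepsilon^{-2})\\
\mathrm{Gc}(F,\varepsilon) &= O(K) = O(\kappa^{5}\varepsilon^{-2})\\
\mathrm{Gc}(G,\varepsilon) &= K \cdot T = O(\kappa^{5}\varepsilon^{-2}) \cdot O(\kappa^{4}) = O(\kappa^{9}\varepsilon^{-2})
\end{aligned}
```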
The O(1) batch size is a significant practical advantage. In the concurrent work of Han et al. (2024), the stochastic bilevel result on manifolds requires a batch size that grows inversely with the target accuracy — meaning to get a more accurate result, you need progressively larger mini-batches. RieSBO avoids this by adopting the ALSET-style estimator, which achieves the same theoretical rate with constant batch sizes throughout training. This makes RieSBO directly compatible with standard stochastic gradient descent infrastructure and far easier to deploy in practice.
Han et al. (2024) — Concurrent Work
Also extends bilevel optimization to Riemannian manifolds. Primarily analyzes the deterministic setting, and the stochastic convergence requires batch size O(ε⁻¹) — which grows as you need more accuracy.
RieSBO — This Paper
Achieves the same convergence rates with O(1) batch size throughout. Primarily focuses on the stochastic setting and adopts the Neumann series estimator, giving a fully practical stochastic algorithm with no growing batch requirement.
Real Applications — Where Riemannian Bilevel Optimization Actually Shows Up
The paper demonstrates two concrete application domains, both of which are natural fits for the Riemannian bilevel framework and would be awkward to handle with flat-space methods.
The first is distributionally robust optimization on Riemannian manifolds. The idea here is to minimize a worst-case loss — assign weights to data points adversarially so that harder examples get higher weight — subject to a regularization term that prevents the weight distribution from being too extreme. When the model parameter y lives on a manifold (say, the manifold of positive definite matrices), the inner problem is to find the model that minimizes the weighted loss, which becomes a geodesically convex problem on the manifold. The outer problem is to find the worst-case weights. This structure is exactly a Riemannian bilevel problem.
The paper tests this on two concrete problems: the robust Karcher mean (finding the manifold-valued “average” of positive definite matrices under adversarial weighting) and robust maximum likelihood estimation of a covariance matrix. In both cases, RieBO converges cleanly and efficiently, with both the function value and the gradient norm decreasing monotonically across iterations.
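The lower-level problem in the robust Karcher mean is a weighted Fréchet mean on the SPD manifold. As a minimal sketch (our own illustration; the fixed step size is a conservative choice for this toy, not a value from the paper), it can be solved by Riemannian gradient descent with the affine-invariant exponential and logarithm maps. For commuting matrices the answer must equal the weighted geometric mean, which gives an easy check.

```python
import numpy as np


def _sym_fn(A, fn):
    """Apply a scalar function to a symmetric matrix via its eigendecomposition."""
    w, V = np.linalg.eigh((A + A.T) / 2)
    return (V * fn(w)) @ V.T


def spd_log(S, A):
    """Affine-invariant log map: Log_S(A) = S^{1/2} logm(S^{-1/2} A S^{-1/2}) S^{1/2}."""
    Sh, Shi = _sym_fn(S, np.sqrt), _sym_fn(S, lambda w: 1.0 / np.sqrt(w))
    return Sh @ _sym_fn(Shi @ A @ Shi, np.log) @ Sh


def spd_exp(S, V):
    """Affine-invariant exponential map: Exp_S(V) = S^{1/2} expm(S^{-1/2} V S^{-1/2}) S^{1/2}."""
    Sh, Shi = _sym_fn(S, np.sqrt), _sym_fn(S, lambda w: 1.0 / np.sqrt(w))
    return Sh @ _sym_fn(Shi @ V @ Shi, np.exp) @ Sh


def weighted_karcher_mean(mats, p, step=0.25, iters=200):
    """Riemannian GD on g(S) = sum_i p_i d^2(S, A_i), whose gradient is -2 sum_i p_i Log_S(A_i)."""
    S = np.eye(mats[0].shape[0])
    for _ in range(iters):
        rgrad = -2.0 * sum(pi * spd_log(S, A) for pi, A in zip(p, mats))
        S = spd_exp(S, -step * rgrad)
    return S


# Commuting (diagonal) inputs: the weighted Karcher mean is the weighted geometric mean.
A1, A2 = np.diag([1.0, 4.0]), np.diag([9.0, 1.0])
S = weighted_karcher_mean([A1, A2], [0.5, 0.5])
print(np.round(S, 6))  # diag(3, 2) = diag(sqrt(1*9), sqrt(4*1))
```

In the bilevel version, the weights p are the upper-level variable and this solver plays the role of the inner loop.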
The second application is Riemannian meta-learning. Meta-learning aims to find a good initialization that can adapt quickly to new tasks with few gradient steps. When the model parameters are constrained to live on a manifold — for instance, the Grassmannian of k-dimensional subspaces of R^n, which appears when learning with orthogonality constraints — the standard MAML-style bilevel algorithm has no theoretical justification. RieSBO provides that justification, and the experiments on 5-way 5-shot classification on MiniImageNet show that RieSBO achieves better test accuracy than a naive projection-based baseline, while converging to a lower training loss.
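When model weights must keep orthonormal columns (the Stiefel manifold St(n, p), a close relative of the Grassmannian), each outer update needs a tangent projection and a retraction. A minimal sketch, assuming the Euclidean metric and the Cayley retraction as a cheap substitute for the exponential map: the skew-symmetric generator W is built so that W X = V, which makes the retraction curve start in the requested direction.

```python
import numpy as np


def stiefel_proj(X, G):
    """Tangent projection on St(n, p) with the Euclidean metric: G - X sym(X^T G)."""
    XtG = X.T @ G
    return G - X @ (XtG + XtG.T) / 2


def cayley_retract(X, V):
    """Cayley retraction R_X(V) = (I - W/2)^{-1} (I + W/2) X, with skew W satisfying W X = V."""
    PV = V - 0.5 * X @ (X.T @ V)
    W = PV @ X.T - X @ PV.T
    I = np.eye(X.shape[0])
    return np.linalg.solve(I - 0.5 * W, (I + 0.5 * W) @ X)


rng = np.random.default_rng(0)
X, _ = np.linalg.qr(rng.standard_normal((6, 3)))  # a point on St(6, 3)
V = stiefel_proj(X, rng.standard_normal((6, 3)))  # a tangent vector at X
Y = cayley_retract(X, V)
print(np.linalg.norm(Y.T @ Y - np.eye(3)))        # ~0: the step stays on the manifold
t = 1e-6
print(np.linalg.norm((cayley_retract(X, t * V) - X) / t - V))  # ~0: initial velocity is V
```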
If you are working on problems with orthogonal weight matrices, positive definite covariance matrices, Grassmannian subspace constraints, or any other manifold structure, and your optimization has a nested bilevel form — hyperparameter tuning, meta-learning, robust optimization — RieBO and RieSBO give you the first theoretically-justified algorithms for your setting, with complexity guarantees as strong as the best flat-space methods.
Complete Implementation — RieBO and RieSBO in Python
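The following implementation provides a self-contained, illustrative version of both algorithms in PyTorch. The SPD and Stiefel manifold operations are written by hand (libraries such as Geomstats or geoopt provide fuller versions), and, as a simplification, derivatives are taken with autograd in ambient Euclidean coordinates before being converted to Riemannian gradients. The code covers the full pipeline: manifold setup, inner-loop Riemannian gradient descent, hypergradient estimation via the AID (conjugate gradient) or Neumann approximation, and outer-loop updates via the exponential map or a retraction.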
"""
Riemannian Bilevel Optimization: RieBO & RieSBO
Paper: Li & Ma, JMLR 26 (2025), https://jmlr.org/papers/v26/24-0397.html

Implements both deterministic (RieBO) and stochastic (RieSBO) algorithms.
Simplification: all derivatives are taken with autograd in ambient
(Euclidean) coordinates; the resulting Euclidean hypergradient is then
converted to a Riemannian gradient on the upper-level manifold.

Dependencies: pip install torch numpy
"""
from typing import Callable, Tuple

import torch
from torch import Tensor


# ─────────────────────────────────────────────────────────────────────────────
# Manifold Utilities
# Basic Riemannian operations used by both algorithms.
# For a full-featured manifold library use geomstats or geoopt.
# Here we provide self-contained implementations for the SPD (Symmetric
# Positive Definite) manifold and the Stiefel manifold (orthogonal frames).
# ─────────────────────────────────────────────────────────────────────────────


class SPDManifold:
    """Manifold of Symmetric Positive Definite (SPD) matrices.

    The geodesic distance uses the affine-invariant metric. The exponential
    map, log map, and parallel transport are all derived from the matrix
    exponential / logarithm.
    """

    def riemannian_grad(self, S: Tensor, euc_grad: Tensor) -> Tensor:
        """Convert a Euclidean gradient into the Riemannian gradient at S.

        For SPD with the affine-invariant metric:
            riem_grad = S @ euc_grad_sym @ S
        where euc_grad_sym is the symmetrized Euclidean gradient.
        """
        G = (euc_grad + euc_grad.T) / 2  # symmetrize
        return S @ G @ S

    def exp_map(self, S: Tensor, V: Tensor) -> Tensor:
        """Exponential map at S in direction V (tangent vector).

        Exp_S(V) = S^{1/2} expm(S^{-1/2} V S^{-1/2}) S^{1/2}
        """
        S_half = self._matrix_sqrt(S)
        S_half_inv = torch.linalg.inv(S_half)
        M = S_half_inv @ V @ S_half_inv
        return S_half @ self._matrix_exp(M) @ S_half

    def log_map(self, S: Tensor, T: Tensor) -> Tensor:
        """Logarithmic map: inverse of exp_map.

        Log_S(T) = S^{1/2} logm(S^{-1/2} T S^{-1/2}) S^{1/2}
        """
        S_half = self._matrix_sqrt(S)
        S_half_inv = torch.linalg.inv(S_half)
        M = S_half_inv @ T @ S_half_inv
        return S_half @ self._matrix_log(M) @ S_half

    def geodesic_dist(self, A: Tensor, B: Tensor) -> Tensor:
        """Affine-invariant geodesic distance between SPD matrices A and B."""
        A_half_inv = torch.linalg.inv(self._matrix_sqrt(A))
        M = A_half_inv @ B @ A_half_inv
        log_eigs = torch.log(torch.linalg.eigvalsh(M))
        return torch.sqrt((log_eigs ** 2).sum())

    def parallel_transport(self, S: Tensor, T: Tensor, V: Tensor) -> Tensor:
        """Parallel transport of tangent vector V from T_S M to T_T M along
        the unique geodesic connecting S and T.

        P_{S->T}(V) = E @ V @ E^T, where E = (T S^{-1})^{1/2}.
        T S^{-1} is not symmetric, so E is computed through the equivalent
        symmetric form E = S^{1/2} (S^{-1/2} T S^{-1/2})^{1/2} S^{-1/2}.
        """
        S_half = self._matrix_sqrt(S)
        S_half_inv = torch.linalg.inv(S_half)
        E = S_half @ self._matrix_sqrt(S_half_inv @ T @ S_half_inv) @ S_half_inv
        return E @ V @ E.T

    def _matrix_sqrt(self, A: Tensor) -> Tensor:
        L, V = torch.linalg.eigh(A)
        L = torch.clamp(L, min=1e-10)
        return V @ torch.diag(L ** 0.5) @ V.T

    def _matrix_exp(self, A: Tensor) -> Tensor:
        L, V = torch.linalg.eigh(A)
        return V @ torch.diag(torch.exp(L)) @ V.T

    def _matrix_log(self, A: Tensor) -> Tensor:
        L, V = torch.linalg.eigh(A)
        L = torch.clamp(L, min=1e-10)
        return V @ torch.diag(torch.log(L)) @ V.T


class StiefelManifold:
    """Manifold of Stiefel (column-orthonormal) matrices St(n, p).

    Elements X satisfy X^T X = I_p. Uses the Euclidean metric on the
    ambient R^{n x p}.
    """

    @staticmethod
    def _sym(A: Tensor) -> Tensor:
        return (A + A.T) / 2

    def riemannian_grad(self, X: Tensor, euc_grad: Tensor) -> Tensor:
        """Project Euclidean gradient onto the tangent space T_X St(n,p).

        proj_T(G) = (I - X X^T) G + X skew(X^T G) = G - X sym(X^T G)
        where skew(A) = (A - A^T) / 2 and sym(A) = (A + A^T) / 2.
        """
        return euc_grad - X @ self._sym(X.T @ euc_grad)

    def exp_map(self, X: Tensor, V: Tensor, dt: float = 1.0) -> Tensor:
        """Cayley retraction as a computationally cheap exp_map substitute.

        Retract_X(dV) = (I - W/2)^{-1} (I + W/2) X, with dV = dt * V and the
        skew-symmetric W = P dV X^T - X dV^T P, P = I - X X^T / 2.
        This choice gives W X = dV for tangent dV, so the retraction curve
        has the correct initial velocity.
        """
        dV = dt * V
        PdV = dV - 0.5 * X @ (X.T @ dV)
        W = PdV @ X.T - X @ PdV.T
        n = X.shape[0]
        I = torch.eye(n, device=X.device, dtype=X.dtype)
        return torch.linalg.solve(I - 0.5 * W, (I + 0.5 * W) @ X)

    def project(self, X: Tensor) -> Tensor:
        """Project X back to Stiefel manifold via QR decomposition."""
        Q, R = torch.linalg.qr(X)
        signs = torch.diag(torch.sign(torch.diag(R)))
        return Q @ signs

    def parallel_transport(self, X: Tensor, Y: Tensor, V: Tensor) -> Tensor:
        """Approximate parallel transport via tangent space projection at Y."""
        return V - Y @ self._sym(Y.T @ V)


# ─────────────────────────────────────────────────────────────────────────────
# Hypergradient Estimation Utilities
# ─────────────────────────────────────────────────────────────────────────────


def conjugate_gradient_manifold(
    hess_vec_prod: Callable[[Tensor], Tensor],
    rhs: Tensor,
    v0: Tensor,
    N: int = 10,
    tol: float = 1e-6,
) -> Tensor:
    """N-step conjugate gradient solver for H[v] = rhs on the tangent space.

    Used in RieBO to approximate the hypergradient. Inner products are the
    Euclidean (Frobenius) inner products of the ambient coordinates.

    Args:
        hess_vec_prod: Callable mapping tangent vector -> tangent vector,
            representing the Hessian action H_y[·].
        rhs: Right-hand side, i.e., grad_y f(x, y).
        v0: Warm-start initial vector (from parallel-transported prev CG output).
        N: Number of CG iterations. Should be O(sqrt(kappa)).
        tol: Convergence tolerance.
    Returns:
        Approximate solution v_N satisfying H[v_N] ≈ rhs.
    """
    v = v0.clone()
    r = rhs - hess_vec_prod(v)
    p = r.clone()
    rs_old = (r * r).sum()
    for _ in range(N):
        Ap = hess_vec_prod(p)
        alpha = rs_old / ((p * Ap).sum() + 1e-12)
        v = v + alpha * p
        r = r - alpha * Ap
        rs_new = (r * r).sum()
        if rs_new.sqrt() < tol:
            break
        p = r + (rs_new / (rs_old + 1e-12)) * p
        rs_old = rs_new
    return v


def neumann_series_estimator(
    hess_vec_prod: Callable[[Tensor], Tensor],
    grad_y_f: Tensor,
    eta: float,
    Q: int = 10,
) -> Tensor:
    """Randomized Neumann series approximation to H_y(g)^{-1} grad_y f.

    Used in RieSBO. Based on H^{-1} = eta * sum_{q>=0} (I - eta*H)^q,
    valid for 0 < eta <= 1/L_g. Following the ALSET-style estimator, a random
    Q0 is drawn uniformly from {0, ..., Q-1} and
        v = Q * eta * prod_{q=1}^{Q0} (I - eta*H_q) [grad_y f],
    where each H_q comes from a fresh sample. In expectation, v equals the
    truncated series eta * sum_{q=0}^{Q-1} (I - eta*H)^q [grad_y f].

    Args:
        hess_vec_prod: Stochastic Hessian-vector product H_y[·]; each call
            should draw a fresh sample.
        grad_y_f: Stochastic gradient of f with respect to y.
        eta: Step size, should satisfy eta <= 1/L_g (Lipschitz constant of grad g).
        Q: Total number of Neumann series terms.
    Returns:
        Approximate solution v to H_y[v] = grad_y f.
    """
    Q0 = int(torch.randint(0, Q, (1,)).item())  # random truncation length
    v = grad_y_f.clone()
    for _ in range(Q0):
        v = v - eta * hess_vec_prod(v)
    return Q * eta * v


# ─────────────────────────────────────────────────────────────────────────────
# RieBO: Deterministic Riemannian Bilevel Optimization (Algorithm 1)
# ─────────────────────────────────────────────────────────────────────────────


class RieBO:
    """Deterministic Riemannian Bilevel Optimization (RieBO).

    Solves:
        min_{x in M} f(x, y*(x))
        s.t. y*(x) = argmin_{y in N} g(x, y)

    Based on Algorithm 1 in Li & Ma (JMLR 2025).

    Args:
        manifold_x: Riemannian manifold object for upper-level variable x.
        manifold_y: Riemannian manifold object for lower-level variable y.
        f: Upper-level objective. Signature: f(x, y) -> scalar tensor.
           Must depend on both x and y.
        g: Lower-level objective. Signature: g(x, y) -> scalar tensor.
           Must depend on x (for the cross-derivative) and y.
        alpha: Outer step size. Recommended: 1 / (8 * L_Phi).
        beta: Inner step size. Recommended: 1 / L_g.
        T: Inner loop iterations. Recommended: O(kappa).
        N: Conjugate gradient steps. Recommended: O(sqrt(kappa)).
    """

    def __init__(
        self,
        manifold_x,
        manifold_y,
        f: Callable,
        g: Callable,
        alpha: float = 1e-2,
        beta: float = 1e-1,
        T: int = 10,
        N: int = 5,
    ):
        self.manifold_x = manifold_x
        self.manifold_y = manifold_y
        self.f = f
        self.g = g
        self.alpha = alpha
        self.beta = beta
        self.T = T
        self.N = N

    def _inner_loop(self, x: Tensor, y_init: Tensor) -> Tensor:
        """Run T steps of Riemannian gradient descent on g(x, ·)."""
        y = y_init.detach().clone().requires_grad_(True)
        for _ in range(self.T):
            loss_g = self.g(x.detach(), y)
            euc_grad_y = torch.autograd.grad(loss_g, y)[0]
            rie_grad_y = self.manifold_y.riemannian_grad(y.detach(), euc_grad_y)
            with torch.no_grad():
                y_new = self.manifold_y.exp_map(y, -self.beta * rie_grad_y)
            y = y_new.requires_grad_(True)
        return y.detach()

    def _estimate_hypergradient(
        self, x: Tensor, y: Tensor, v_prev: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """AID-based hypergradient estimation (ambient-coordinate derivatives).

        Returns:
            (hypergradient in T_x M, updated CG solution v)
        """
        x_req = x.detach().requires_grad_(True)
        y_req = y.detach().requires_grad_(True)

        # grad_x f and grad_y f
        loss_f = self.f(x_req, y_req)
        euc_grad_x_f, euc_grad_y_f = torch.autograd.grad(loss_f, [x_req, y_req])

        # grad_y g with a graph, reused for Hessian-vector and cross-derivative products
        loss_g = self.g(x_req, y_req)
        grad_y_g = torch.autograd.grad(loss_g, y_req, create_graph=True)[0]

        def hess_y_g_vec(v: Tensor) -> Tensor:
            hvp = torch.autograd.grad(
                (grad_y_g * v).sum(), y_req, retain_graph=True
            )[0]
            return hvp.detach()

        # Solve H_y(g)[v] = grad_y f using warm-started conjugate gradient
        v_hat = conjugate_gradient_manifold(
            hess_y_g_vec, euc_grad_y_f.detach(), v_prev, N=self.N
        )

        # Adjoint cross-derivative applied to v_hat: d/dx <grad_y g, v_hat>
        cross_deriv = torch.autograd.grad((grad_y_g * v_hat).sum(), x_req)[0]

        # Euclidean hypergradient: grad_x f - cross_deriv
        euc_hyper = euc_grad_x_f.detach() - cross_deriv.detach()

        # Convert to the Riemannian gradient in T_x M
        rie_hyper = self.manifold_x.riemannian_grad(x.detach(), euc_hyper)
        return rie_hyper, v_hat

    def optimize(
        self,
        x0: Tensor,
        y0: Tensor,
        K: int = 100,
        verbose: bool = True,
        log_every: int = 10,
    ) -> Tuple[Tensor, dict]:
        """Run K outer iterations of RieBO.

        Args:
            x0: Initial upper-level variable on M.
            y0: Initial lower-level variable on N.
            K: Number of outer iterations.
            verbose: Print progress if True.
            log_every: Log every this many iterations.
        Returns:
            (x_final, history_dict)
        """
        x = x0.clone()
        y = y0.clone()
        y_prev = y.clone()
        v_cg = torch.zeros_like(y)  # CG warm-start vector
        history = {"phi_vals": [], "grad_norms": [], "iterations": []}

        for k in range(K):
            # Step 1: Inner loop (warm-started at the previous y) approximates y*(x)
            y = self._inner_loop(x, y)

            # Step 2: Parallel-transport the previous CG solution to the current y
            if k > 0:
                v_cg = self.manifold_y.parallel_transport(y_prev, y, v_cg)

            # Step 3: Estimate hypergradient
            rie_hyper, v_cg = self._estimate_hypergradient(x, y, v_cg)

            # Step 4: Update x via exponential map (or retraction)
            x = self.manifold_x.exp_map(x, -self.alpha * rie_hyper)

            # Logging: f at the new x and the current inner iterate (an approximation of Φ)
            grad_norm = rie_hyper.norm().item()
            phi_val = self.f(x.detach(), y.detach()).item()
            history["phi_vals"].append(phi_val)
            history["grad_norms"].append(grad_norm)
            history["iterations"].append(k)

            if verbose and k % log_every == 0:
                print(f"[RieBO] iter={k:4d} Φ={phi_val:.6f} ‖grad‖={grad_norm:.4e}")

            y_prev = y.clone()

        return x, history


# ─────────────────────────────────────────────────────────────────────────────
# RieSBO: Stochastic Riemannian Bilevel Optimization (Algorithm 2)
# ─────────────────────────────────────────────────────────────────────────────


class RieSBO:
    """Stochastic Riemannian Bilevel Optimization (RieSBO).

    Solves the stochastic version:
        min_{x in M} E_xi[F(x, y*(x); xi)]
        s.t. y*(x) = argmin_{y in N} E_zeta[G(x, y; zeta)]

    Uses a randomized Neumann series for hypergradient estimation (batch size O(1)).
    Based on Algorithm 2 in Li & Ma (JMLR 2025).

    Args:
        manifold_x: Riemannian manifold for upper-level variable.
        manifold_y: Riemannian manifold for lower-level variable.
        F: Stochastic upper-level function. Signature: F(x, y, xi) -> scalar.
        G: Stochastic lower-level function. Signature: G(x, y, zeta) -> scalar.
        sample_xi: Callable () -> xi (draw upper-level sample).
        sample_zeta: Callable () -> zeta (draw lower-level sample).
        alpha: Outer step size. Theory: O(kappa^{-2.5} K^{-0.5}).
        beta: Inner step size. Theory: O(kappa^{-1.75} K^{-0.5}).
        eta: Neumann series step size. Should be <= 1/L_g.
        T: Inner SGD steps. Theory: O(kappa^4) (use smaller in practice).
        Q: Neumann series terms. Theory: O(kappa log K).
    """

    def __init__(
        self,
        manifold_x,
        manifold_y,
        F: Callable,
        G: Callable,
        sample_xi: Callable,
        sample_zeta: Callable,
        alpha: float = 5e-3,
        beta: float = 1e-2,
        eta: float = 1e-1,
        T: int = 20,
        Q: int = 10,
    ):
        self.manifold_x = manifold_x
        self.manifold_y = manifold_y
        self.F = F
        self.G = G
        self.sample_xi = sample_xi
        self.sample_zeta = sample_zeta
        self.alpha = alpha
        self.beta = beta
        self.eta = eta
        self.T = T
        self.Q = Q

    def _stochastic_inner_loop(self, x: Tensor, y_init: Tensor) -> Tensor:
        """T steps of Riemannian stochastic gradient descent on G(x, ·; zeta)."""
        y = y_init.detach().clone().requires_grad_(True)
        for _ in range(self.T):
            zeta = self.sample_zeta()
            loss_G = self.G(x.detach(), y, zeta)
            euc_grad_y = torch.autograd.grad(loss_G, y)[0]
            rie_grad_y = self.manifold_y.riemannian_grad(y.detach(), euc_grad_y)
            with torch.no_grad():
                y_new = self.manifold_y.exp_map(y, -self.beta * rie_grad_y)
            y = y_new.requires_grad_(True)
        return y.detach()

    def _stochastic_hypergradient(self, x: Tensor, y: Tensor) -> Tensor:
        """Estimate hypergradient using the randomized Neumann series (O(1) batch size).

        Uses independent samples xi for F and {zeta_q} for G.
        """
        xi = self.sample_xi()
        x_req = x.detach().requires_grad_(True)
        y_req = y.detach().requires_grad_(True)

        # Stochastic grad_x F and grad_y F
        loss_F = self.F(x_req, y_req, xi)
        euc_grad_x_F, euc_grad_y_F = torch.autograd.grad(loss_F, [x_req, y_req])

        # Stochastic Hessian-vector product with a fresh sample per call
        def stoch_hess_vec(v: Tensor) -> Tensor:
            zeta_q = self.sample_zeta()
            loss_G = self.G(x_req, y_req, zeta_q)
            grad_y_G = torch.autograd.grad(loss_G, y_req, create_graph=True)[0]
            hvp = torch.autograd.grad((grad_y_G * v).sum(), y_req)[0]
            return hvp.detach()

        # Neumann series approximation to H^{-1} grad_y F
        v_Q = neumann_series_estimator(
            stoch_hess_vec, euc_grad_y_F.detach(), eta=self.eta, Q=self.Q
        )

        # Stochastic cross-derivative via a fresh zeta_0 sample (G must depend on x)
        zeta_0 = self.sample_zeta()
        loss_G0 = self.G(x_req, y_req, zeta_0)
        grad_y_G0 = torch.autograd.grad(loss_G0, y_req, create_graph=True)[0]
        cross_deriv = torch.autograd.grad((grad_y_G0 * v_Q).sum(), x_req)[0]

        euc_hyper = euc_grad_x_F.detach() - cross_deriv.detach()
        return self.manifold_x.riemannian_grad(x.detach(), euc_hyper)

    def optimize(
        self,
        x0: Tensor,
        y0: Tensor,
        K: int = 200,
        verbose: bool = True,
        log_every: int = 20,
    ) -> Tuple[Tensor, dict]:
        """Run K outer iterations of RieSBO.

        Returns:
            (x_final, history_dict)
        """
        x = x0.clone()
        y = y0.clone()
        history = {"grad_norms": [], "iterations": []}

        for k in range(K):
            # Step 1: Stochastic inner loop (warm-started at the previous y)
            y = self._stochastic_inner_loop(x, y)

            # Step 2: Estimate stochastic hypergradient
            rie_hyper = self._stochastic_hypergradient(x, y)

            # Step 3: Update x via exponential map (or retraction)
            x = self.manifold_x.exp_map(x, -self.alpha * rie_hyper)

            grad_norm = rie_hyper.norm().item()
            history["grad_norms"].append(grad_norm)
            history["iterations"].append(k)

            if verbose and k % log_every == 0:
                print(f"[RieSBO] iter={k:4d} ‖grad‖={grad_norm:.4e}")

        return x, history


# ─────────────────────────────────────────────────────────────────────────────
# Demo: Robust Karcher Mean on SPD Manifold
# Solves problem (7.2) from the paper.
# ─────────────────────────────────────────────────────────────────────────────


def demo_robust_karcher_mean(d: int = 5, n: int = 5, K: int = 50):
    """Demonstrate RieBO on the robust Karcher mean bilevel problem.

    Upper level: find worst-case weight vector p in the probability simplex.
    Lower level: find the Karcher mean S of SPD matrices under weights p.
    The upper-level "manifold" is the simplex (handled via projection).
    The lower-level manifold is the SPD manifold.
    """
    print("\n=== Robust Karcher Mean Demo (SPD Manifold) ===")
    torch.manual_seed(42)

    # Generate random SPD data matrices
    spd = SPDManifold()
    data_matrices = []
    for _ in range(n):
        A = torch.randn(d, d)
        data_matrices.append((A @ A.T + 0.5 * torch.eye(d)).detach())

    lambd = 0.1  # regularization
    uniform = torch.ones(n) / n

    def sqd_dist(S: Tensor, A: Tensor) -> Tensor:
        return spd.geodesic_dist(S, A) ** 2

    # Upper level: we minimize f = reg - weighted_loss (i.e. maximize the
    # weighted loss minus the regularizer). S is not detached, since the
    # hypergradient needs grad_y f.
    def f_upper(p: Tensor, S: Tensor) -> Tensor:
        weighted = sum(p[i] * sqd_dist(S, data_matrices[i]) for i in range(n))
        reg = lambd * ((p - uniform) ** 2).sum()
        return reg - weighted

    # Lower level: weighted Fréchet (Karcher) mean. p is not detached, since
    # the cross-derivative needs g to depend on the upper-level variable.
    def g_lower(p: Tensor, S: Tensor) -> Tensor:
        return sum(p[i] * sqd_dist(S, data_matrices[i]) for i in range(n))

    # Simplex "manifold": Euclidean steps followed by projection onto the simplex
    class SimplexManifold:
        def riemannian_grad(self, p, g):
            return g - g.mean()  # project onto {v : sum(v) = 0}

        def exp_map(self, p, v):
            p_new = p + v
            # Euclidean projection onto the probability simplex (sort-based)
            p_sorted, _ = torch.sort(p_new, descending=True)
            cumsum = torch.cumsum(p_sorted, dim=0)
            ks = torch.arange(1, p_new.numel() + 1, dtype=p_new.dtype)
            rho = int(torch.nonzero(p_sorted * ks > cumsum - 1).max()) + 1
            theta = (cumsum[rho - 1] - 1) / rho
            return torch.clamp(p_new - theta, min=0)

        def parallel_transport(self, p, q, v):
            return v

    # Initialize
    p0 = uniform.clone()
    S0 = torch.eye(d) + 0.1 * torch.randn(d, d)
    S0 = S0 @ S0.T + torch.eye(d)  # ensure SPD

    solver = RieBO(
        manifold_x=SimplexManifold(),
        manifold_y=SPDManifold(),
        f=f_upper,
        g=g_lower,
        alpha=0.01,
        beta=0.05,
        T=10,
        N=5,
    )
    p_opt, hist = solver.optimize(p0, S0, K=K, verbose=True, log_every=10)
    print(f"\nFinal p (weights): {p_opt.numpy().round(4)}")
    print(f"Final |grad|: {hist['grad_norms'][-1]:.4e}")
    return p_opt, hist


# ─────────────────────────────────────────────────────────────────────────────
# Demo: Riemannian Meta-Learning on Stiefel Manifold
# Illustrates RieSBO on a toy version of problem (2.3) from the paper.
# ─────────────────────────────────────────────────────────────────────────────


def demo_riemannian_meta_learning(n: int = 6, p: int = 3, K: int = 50):
    """Demonstrate RieSBO on a toy Riemannian meta-learning problem on the Stiefel manifold.

    Outer variable: shared representation W in St(n, p).
    Inner variable: head parameter w in R^p (Euclidean for simplicity).
    Upper-level loss: query loss of a randomly sampled task.
    Lower-level problem: ridge-regularized support loss of a randomly sampled task.
    """
    print("\n=== Riemannian Meta-Learning Demo (Stiefel Manifold) ===")
    torch.manual_seed(7)
    num_tasks = 4
    stiefel = StiefelManifold()

    # Simulate tasks: each task has a random target vector in R^n
    task_targets = [torch.randn(n) for _ in range(num_tasks)]

    def sample_task():
        return torch.randint(0, num_tasks, (1,)).item()

    def F_upper(W: Tensor, w: Tensor, t) -> Tensor:
        # Query loss: predict with W @ w, measure against task target t
        pred = W @ w
        return ((pred - task_targets[t]) ** 2).mean()

    def G_lower(W: Tensor, w: Tensor, t) -> Tensor:
        # Support loss + L2 regularizer for strong convexity; W is not detached
        # so that the cross-derivative with respect to W is defined
        pred = W @ w
        return ((pred - task_targets[t]) ** 2).mean() + 0.5 * (w ** 2).sum()

    # Euclidean manifold for inner variable w
    class EuclideanManifold:
        def riemannian_grad(self, x, g):
            return g

        def exp_map(self, x, v):
            return x + v

        def parallel_transport(self, x, y, v):
            return v

    # Initialize W as a Stiefel matrix via QR
    W0, _ = torch.linalg.qr(torch.randn(n, p))
    w0 = torch.zeros(p)

    solver = RieSBO(
        manifold_x=stiefel,
        manifold_y=EuclideanManifold(),
        F=F_upper,
        G=G_lower,
        sample_xi=sample_task,
        sample_zeta=sample_task,
        alpha=5e-3,
        beta=5e-2,
        eta=0.5,
        T=5,
        Q=8,
    )
    W_opt, hist = solver.optimize(W0, w0, K=K, verbose=True, log_every=10)

    # Check orthogonality of result
    ortho_err = (W_opt.T @ W_opt - torch.eye(p)).norm().item()
    print(f"\nOrthogonality error ‖W^T W - I‖: {ortho_err:.4e}")
    print(f"Final |grad|: {hist['grad_norms'][-1]:.4e}")
    return W_opt, hist


if __name__ == "__main__":
    # Run both demos
    demo_robust_karcher_mean(d=5, n=5, K=50)
    demo_riemannian_meta_learning(n=6, p=3, K=50)
The code is structured in three layers. The manifold utilities (SPDManifold, StiefelManifold) handle the Riemannian geometry: exponential and logarithm maps, Riemannian gradients, and parallel transport. The hypergradient estimators implement the two approximations of the inverse-Hessian-vector product: conjugate gradient for RieBO and a truncated Neumann series for RieSBO. The solver classes (RieBO, RieSBO) implement the outer loops of Algorithms 1 and 2 from the paper. The two demos make the Riemannian bilevel structure concrete: a robust Karcher mean on the SPD manifold (with a small SimplexManifold helper for the weights) and meta-learning on the Stiefel manifold (with a EuclideanManifold helper for the task-specific parameters).
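To make the two hypergradient estimators concrete, here is a minimal Euclidean sketch of the linear solve they both approximate, H v = b, where H stands in for the lower-level Hessian and b for the upper-level gradient in the inner variable. It assumes a small, explicit, symmetric positive definite H and deterministic arithmetic; RieBO and RieSBO apply the same ideas with Riemannian Hessian-vector products (and, for RieSBO, stochastic samples), which this toy leaves out. The function names are our own, not the demo code's API.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(H: Matrix, v: Vector) -> Vector:
    return [sum(h * x for h, x in zip(row, v)) for row in H]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def conjugate_gradient(H: Matrix, b: Vector, iters: int) -> Vector:
    """Approximately solve H v = b for symmetric positive definite H."""
    v = [0.0] * len(b)
    r = b[:]  # residual b - H v (v starts at zero)
    d = r[:]  # search direction
    rs = dot(r, r)
    for _ in range(iters):
        if rs < 1e-30:  # already converged
            break
        Hd = matvec(H, d)
        step = rs / dot(d, Hd)
        v = [vi + step * di for vi, di in zip(v, d)]
        r = [ri - step * hi for ri, hi in zip(r, Hd)]
        rs_new = dot(r, r)
        d = [ri + (rs_new / rs) * di for ri, di in zip(r, d)]
        rs = rs_new
    return v


def neumann_series(H: Matrix, b: Vector, eta: float, Q: int) -> Vector:
    """Approximate H^{-1} b by eta * sum_{q=0}^{Q} (I - eta H)^q b."""
    term = b[:]   # (I - eta H)^0 b
    total = b[:]
    for _ in range(Q):
        Hterm = matvec(H, term)
        term = [t - eta * h for t, h in zip(term, Hterm)]  # next power applied to b
        total = [s + t for s, t in zip(total, term)]
    return [eta * s for s in total]


# Toy system with a known answer: H^{-1} b = [0.5, 1.5] / 1.75
H = [[2.0, 0.5], [0.5, 1.0]]
b = [1.0, 1.0]
v_cg = conjugate_gradient(H, b, iters=2)
v_ns = neumann_series(H, b, eta=0.5, Q=50)
print(v_cg, v_ns)
```

Conjugate gradient solves an n-by-n system in at most n iterations in exact arithmetic, while the Neumann series only needs repeated Hessian-vector products, which is why it pairs naturally with stochastic Hessian estimates. The series converges only when 0 < eta < 2 / lambda_max(H).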
What the Experiments Revealed
The paper’s numerical experiments cover two families of problems: robust estimation on the SPD manifold and Riemannian meta-learning. They are designed to do more than show that the algorithms run: the estimation experiments track convergence in detail, and the meta-learning experiment pits the manifold-aware approach against a naive Euclidean projection baseline.
For the robust Karcher mean and the robust maximum likelihood estimation of covariance matrices, RieBO (via Algorithm 3 in the paper, which adds a simplex projection step to handle the upper-level constraint) converges cleanly in both function value and gradient mapping norm. The convergence curves show smooth monotone decrease, confirming that the warm-start strategy for the inner loop and the conjugate gradient solve is working as intended. The robust MLE experiment scales to much larger dimensions (d = 50) because the Riemannian gradient and Hessian-vector products for the Gaussian log-likelihood have closed-form expressions that are fast to evaluate.
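The closed-form expressions mentioned above are easy to check numerically. A minimal sketch, assuming a zero-mean Gaussian, 2-by-2 matrices, and the affine-invariant metric on SPD matrices: the averaged negative log-likelihood f(Σ) = ½ log det Σ + ½ tr(Σ⁻¹C), with C the sample covariance, has Euclidean gradient ½(Σ⁻¹ - Σ⁻¹CΣ⁻¹), and the Riemannian gradient Σ(∇f)Σ collapses to ½(Σ - C). The robust MLE in the paper also reweights the samples, which this toy omits; the helper names are our own.

```python
import math
from typing import List

Mat2 = List[List[float]]  # 2x2 matrix as nested lists


def det2(S: Mat2) -> float:
    return S[0][0] * S[1][1] - S[0][1] * S[1][0]


def inv2(S: Mat2) -> Mat2:
    d = det2(S)
    return [[S[1][1] / d, -S[0][1] / d], [-S[1][0] / d, S[0][0] / d]]


def mul2(A: Mat2, B: Mat2) -> Mat2:
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)] for i in range(2)]


def nll(S: Mat2, C: Mat2) -> float:
    """Averaged zero-mean Gaussian NLL up to constants: 0.5 log det S + 0.5 tr(S^-1 C)."""
    SiC = mul2(inv2(S), C)
    return 0.5 * math.log(det2(S)) + 0.5 * (SiC[0][0] + SiC[1][1])


def euclidean_grad(S: Mat2, C: Mat2) -> Mat2:
    """Closed form 0.5 (S^-1 - S^-1 C S^-1)."""
    Si = inv2(S)
    SiCSi = mul2(mul2(Si, C), Si)
    return [[0.5 * (Si[i][j] - SiCSi[i][j]) for j in range(2)] for i in range(2)]


def riemannian_grad(S: Mat2, C: Mat2) -> Mat2:
    """Affine-invariant Riemannian gradient S (grad f) S, which simplifies to 0.5 (S - C)."""
    return [[0.5 * (S[i][j] - C[i][j]) for j in range(2)] for i in range(2)]


def finite_diff_grad(S: Mat2, C: Mat2, eps: float = 1e-6) -> Mat2:
    """Central differences, perturbing one matrix entry at a time."""
    G = [[0.0, 0.0], [0.0, 0.0]]
    for i in range(2):
        for j in range(2):
            Sp = [row[:] for row in S]
            Sm = [row[:] for row in S]
            Sp[i][j] += eps
            Sm[i][j] -= eps
            G[i][j] = (nll(Sp, C) - nll(Sm, C)) / (2 * eps)
    return G


S = [[2.0, 0.3], [0.3, 1.0]]
C = [[1.0, 0.2], [0.2, 0.5]]
G_fd = finite_diff_grad(S, C)    # numerical Euclidean gradient
G_e = euclidean_grad(S, C)       # closed-form Euclidean gradient
G_r = mul2(mul2(S, G_e), S)      # S (grad f) S, should equal 0.5 (S - C)
print(G_fd, G_e, G_r)
```

Evaluating ½(Σ - C) needs no inverses or matrix functions at all, which is one reason this problem scales comfortably to larger d.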
For Riemannian meta-learning on the Grassmannian manifold — a 5-way 5-shot classification task on MiniImageNet — RieSBO is compared against a naive baseline that runs standard Euclidean bilevel optimization and projects the result onto the Grassmannian at each step. RieSBO achieves lower training loss and higher test accuracy, demonstrating that the manifold-aware geometry genuinely improves the solution quality, not just the theory.
“For Grassmannian Riemannian meta-learning, the RieSBO algorithm shows better performance in terms of both the training loss and the test accuracy compared to the naive projection-based stochastic bilevel baseline.” — Jiaxiang Li & Shiqian Ma · JMLR 26 (2025)
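The difference between the manifold-aware update and the projection baseline already shows up in a single step. The sketch below uses the unit sphere as a simple stand-in for the Grassmannian (an assumption made for brevity; it is not the paper's experiment): the Riemannian step first removes the gradient component normal to the manifold, while the naive step moves along the full Euclidean gradient and then renormalizes.

```python
import math
from typing import List, Tuple

Vector = List[float]


def normalize(x: Vector) -> Vector:
    r = math.sqrt(sum(v * v for v in x))
    return [v / r for v in x]


def riemannian_step(x: Vector, egrad: Vector, alpha: float) -> Tuple[Vector, Vector]:
    """Project the Euclidean gradient onto the tangent space at x, step, then retract."""
    inner = sum(g * v for g, v in zip(egrad, x))
    rgrad = [g - inner * v for g, v in zip(egrad, x)]  # tangent component only
    x_new = normalize([v - alpha * g for v, g in zip(x, rgrad)])
    return x_new, rgrad


def projected_euclidean_step(x: Vector, egrad: Vector, alpha: float) -> Vector:
    """Naive baseline: full Euclidean step, then project back onto the sphere."""
    return normalize([v - alpha * g for v, g in zip(x, egrad)])


# A gradient with a large normal component at x0 = e1
x0 = [1.0, 0.0, 0.0]
egrad = [10.0, 1.0, 0.0]
x_riem, rgrad = riemannian_step(x0, egrad, alpha=0.2)
x_proj = projected_euclidean_step(x0, egrad, alpha=0.2)
print(x_riem, x_proj)
```

Here the projected step lands near the antipode -e1, while the Riemannian step makes a small move along the tangent direction.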
Open Questions and What Comes Next
The paper is honest about what it leaves open, and those open questions point toward a rich research agenda. The convergence results depend on the sectional curvature of the lower-level manifold N through the quantity τ(ι, c), which measures how “twisted” geodesic distances become on a curved space. Eliminating this curvature dependence — making the convergence rate truly independent of the manifold’s geometry — is listed as the first open question. Some recent work on minimax problems on Riemannian manifolds (Cai et al., 2023) has made progress in this direction for simpler problem structures.
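For intuition about how curvature enters such bounds, consider the constant from Zhang and Sra [10], which appears in convergence analyses of geodesically convex optimization. We use it here only as an illustrative analogue of the paper's τ(ι, c), not as its exact definition: for a lower bound κ < 0 on sectional curvature and a bound c on distances, ζ(κ, c) = √|κ|·c / tanh(√|κ|·c), and ζ = 1 when κ ≥ 0.

```python
import math


def zeta(kappa: float, c: float) -> float:
    """Curvature constant in the style of Zhang & Sra (2016).

    kappa: lower bound on sectional curvature; c: bound on geodesic distances.
    Equals 1 for nonnegative curvature and grows with sqrt(|kappa|) * c otherwise.
    """
    if kappa >= 0:
        return 1.0
    s = math.sqrt(-kappa) * c
    if s == 0.0:
        return 1.0  # limit of s / tanh(s) as s -> 0
    return s / math.tanh(s)


for c in (0.1, 1.0, 10.0):
    print(c, zeta(-1.0, c))
```

The constant stays close to 1 on small regions and grows roughly linearly in √|κ|·c on large ones, which is how negative curvature inflates the constants in rate bounds.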
The paper also notes that computing Riemannian Hessian-vector products efficiently remains a bottleneck for large-scale implementations. For the robust Karcher mean problem, the Hessian must be computed via a finite-difference loop over matrix basis vectors, which scales as O(d²) — fine for d = 20, but prohibitive for d = 1000. Developing automatic differentiation frameworks that can compute Riemannian Hessians efficiently for general manifold-valued functions is an important open engineering challenge.
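The O(d²) cost comes from needing one Hessian-vector product per basis direction. A minimal Euclidean sketch, assuming only a gradient oracle and central finite differences (it ignores the metric and connection terms a Riemannian Hessian would also need): a d-by-d matrix variable flattened into d² coordinates would require d² such products, each costing two gradient evaluations.

```python
from typing import Callable, List

Vector = List[float]


def hvp_fd(grad: Callable[[Vector], Vector], x: Vector, v: Vector, eps: float = 1e-5) -> Vector:
    """Central-difference Hessian-vector product: (grad(x + eps v) - grad(x - eps v)) / (2 eps)."""
    xp = [a + eps * b for a, b in zip(x, v)]
    xm = [a - eps * b for a, b in zip(x, v)]
    gp, gm = grad(xp), grad(xm)
    return [(a - b) / (2 * eps) for a, b in zip(gp, gm)]


def hessian_by_basis(grad: Callable[[Vector], Vector], x: Vector, eps: float = 1e-5) -> List[Vector]:
    """Full Hessian via one HVP per basis vector, i.e. 2 * len(x) gradient calls."""
    n = len(x)
    cols = []
    for k in range(n):
        e = [0.0] * n
        e[k] = 1.0
        cols.append(hvp_fd(grad, x, e, eps))  # column k = H e_k
    return [[cols[j][i] for j in range(n)] for i in range(n)]


# Quadratic f(x) = 0.5 x^T A x, so grad f = A x and the Hessian is A
A = [[4.0, 1.0, 0.0, 0.0],
     [1.0, 3.0, 0.5, 0.0],
     [0.0, 0.5, 2.0, 0.2],
     [0.0, 0.0, 0.2, 1.0]]
calls = [0]  # counts gradient evaluations


def grad_quadratic(x: Vector) -> Vector:
    calls[0] += 1
    return [sum(a * b for a, b in zip(row, x)) for row in A]


x = [1.0, -2.0, 0.5, 3.0]
H_est = hessian_by_basis(grad_quadratic, x)
print(calls[0], H_est)
```

A single Hessian-vector product, which is all an iterative solver needs, costs two gradient calls regardless of dimension; it is materializing the full Hessian that drives the quadratic cost.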
A third direction is momentum-based acceleration. In the Euclidean bilevel literature, recent work (Khanduri et al., Yang et al.) has achieved O(ε⁻¹·⁵) complexity for stochastic bilevel problems using momentum-based gradient estimators, improving over the O(ε⁻²) rate of RieSBO. Adapting these momentum methods to the Riemannian setting requires carefully handling parallel transport in the momentum accumulation step, which the authors explicitly defer to future work.
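Where the transport step sits in a momentum method can be seen in a single-level setting. The sketch below runs heavy-ball momentum on the unit sphere for the Rayleigh quotient f(x) = xᵀAx, assuming projection onto the new tangent space as a cheap vector transport in place of exact parallel transport. It is not the method of Khanduri et al. or Yang et al., and not a bilevel algorithm: it only shows that the accumulated momentum lives in the tangent space at the old iterate and must be moved before it is combined with the next gradient.

```python
import math
from typing import List

Vector = List[float]


def normalize(x: Vector) -> Vector:
    r = math.sqrt(sum(a * a for a in x))
    return [a / r for a in x]


def project_tangent(x: Vector, v: Vector) -> Vector:
    """Remove the component of v along x (tangent space of the sphere at x)."""
    inner = sum(a * b for a, b in zip(x, v))
    return [b - inner * a for a, b in zip(x, v)]


def rayleigh_egrad(A: List[Vector], x: Vector) -> Vector:
    """Euclidean gradient 2 A x of f(x) = x^T A x (A symmetric)."""
    return [2.0 * sum(a * b for a, b in zip(row, x)) for row in A]


def riemannian_momentum(A: List[Vector], x0: Vector, alpha: float = 0.1,
                        beta: float = 0.5, iters: int = 200) -> Vector:
    x = normalize(x0)
    m = [0.0] * len(x)
    for _ in range(iters):
        g = project_tangent(x, rayleigh_egrad(A, x))                # Riemannian gradient at x
        m = [beta * mi + gi for mi, gi in zip(m, g)]                # momentum in T_x
        x_new = normalize([xi - alpha * mi for xi, mi in zip(x, m)])  # retraction
        m = project_tangent(x_new, m)                               # transport m to T_{x_new}
        x = x_new
    return x


A = [[3.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]]
x_final = riemannian_momentum(A, [1.0, 1.0, 1.0])
print(x_final)  # close to +/- e3, the eigenvector of the smallest eigenvalue
```

Skipping the transport line would add a vector from the old tangent space to one in the new tangent space; on a sphere the error is small for short steps, but the convergence analysis has to account for it, which is the difficulty the authors defer.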
Read the Full Paper & Code
Published in the Journal of Machine Learning Research, Volume 26 (2025). The paper is open-access, and the code is publicly available on GitHub.
Jiaxiang Li and Shiqian Ma. Riemannian Bilevel Optimization. Journal of Machine Learning Research, 26 (2025) 1–44. http://jmlr.org/papers/v26/24-0397.html
This article is an independent editorial analysis of a peer-reviewed paper. Mathematical statements paraphrase the original results. For complete proofs, precise formulations, and full algorithmic pseudocode, please consult the published paper. The paper was submitted 3/24, revised 12/24, and published January 2025.
References
- [1] P.-A. Absil, R. Mahony, R. Sepulchre. Optimization Algorithms on Matrix Manifolds. Princeton University Press, 2008.
- [2] N. Boumal. An Introduction to Optimization on Smooth Manifolds. Cambridge University Press, 2023.
- [3] T. Chen, Y. Sun, W. Yin. Closing the Gap: Tighter Analysis of Alternating Stochastic Gradient Methods for Bilevel Problems. NeurIPS, 2021.
- [4] A. Han, B. Mishra, P. Jawanpuria, A. Takeda. A Framework for Bilevel Optimization on Riemannian Manifolds. arXiv:2402.03883, 2024.
- [5] M. Hong, H.-T. Wai, Z. Wang, Z. Yang. A Two-Timescale Stochastic Algorithm Framework for Bilevel Optimization. SIAM Journal on Optimization, 33(1):147–180, 2023.
- [6] F. Huang, S. Gao. Gradient Descent Ascent for Minimax Problems on Riemannian Manifolds. IEEE TPAMI, 45(7):8466–8476, 2023.
- [7] K. Ji, J. Yang, Y. Liang. Bilevel Optimization: Convergence Analysis and Enhanced Design. ICML, 2021.
- [8] S. Ghadimi, M. Wang. Approximation Methods for Bilevel Programming. arXiv:1802.02246, 2018.
- [9] Y. Yang, P. Xiao, K. Ji. Achieving O(ε⁻¹·⁵) Complexity in Hessian/Jacobian-Free Stochastic Bilevel Optimization. NeurIPS, 2023.
- [10] H. Zhang, S. Sra. First-Order Methods for Geodesically Convex Optimization. COLT, 2016.
- [11] Y. Cai, M. Jordan, T. Lin et al. Curvature-Independent Last-Iterate Convergence for Games on Riemannian Manifolds. arXiv:2306.16617, 2023.
- [12] P. Khanduri, S. Zeng, M. Hong et al. A Near-Optimal Algorithm for Stochastic Bilevel Optimization via Double-Momentum. NeurIPS, 2021.
