Natural gradient via optimal transport
Abstract: We study a natural Wasserstein gradient flow on manifolds of probability distributions with discrete sample spaces. We derive the Riemannian structure for the probability simplex from the dynamical formulation of the Wasserstein distance on a weighted graph. We pull back the geometric structure to the parameter space of any given probability model, which allows us to define a natural gradient flow there. In contrast to the natural Fisher-Rao gradient, the natural Wasserstein gradient incorporates a ground metric on sample space. We illustrate the analysis of elementary exponential family examples and demonstrate an application of the Wasserstein natural gradient to maximum likelihood estimation.
Explain it Like I'm 14
Overview
This paper connects two powerful ideas from math and machine learning to design a smarter way to “go downhill” when training models. It builds a new kind of natural gradient method that takes into account how similar or different outcomes are. The key tool is the Wasserstein (Earth Mover’s) distance, which measures how hard it is to shift one pile of probability mass into another when moving mass across a map has a cost. The result is a “Wasserstein natural gradient” that can be used to optimize model parameters, especially for data that live on discrete categories (like words, classes, or nodes in a graph).
Key questions the paper asks
- How can we measure distances and directions between probability distributions in a way that respects a chosen “ground” map of how outcomes relate to each other?
- How can we bring that notion of distance down to the parameter space of a statistical model, so we can run a natural gradient method there?
- How does this compare to the classic Fisher-Rao natural gradient, which treats categories as unrelated?
- What does this look like for discrete data (where you can’t slide continuously between points), and how can we compute it?
How the researchers approached the problem
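To make the ideas approachable, think of the method in three layers: the two ideas it builds on, their adaptation to discrete data on a graph, and the transfer of the resulting geometry to model parameters.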
Two big ideas behind the scenes
- Optimal transport and the Wasserstein distance: Think of two histograms as two piles of sand. The Earth Mover’s distance asks: what’s the least “work” needed to move one pile into the other, where “work” = amount of sand moved × how far you move it? This adds a “ground metric,” a map that says how far apart the bins are.
- Information geometry and natural gradients: When optimizing over probabilities, not all directions are equally meaningful. A “natural gradient” is like the steepest descent but measured with a geometry that fits probability space. The classic choice uses the Fisher-Rao metric, which ignores any physical or semantic distances between categories.
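As a minimal illustration of the sand-pile picture, the sketch below computes the Earth Mover's distance between two histograms on equally spaced bins along a line, where it reduces to summing the mass that must cross each bin boundary. This is the classic static distance on a line, not the dynamic graph construction the paper uses; the function name `emd_on_line` and the example histograms are illustrative assumptions.

```python
import numpy as np

def emd_on_line(p, q, spacing=1.0):
    """Earth Mover's distance between histograms on equally spaced bins."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Entry k is the net mass that has to cross the boundary between bin k and bin k+1.
    crossing = np.cumsum(p - q)[:-1]
    return spacing * np.abs(crossing).sum()

def total_variation(p, q):
    """A distance that ignores where the bins sit relative to each other."""
    return 0.5 * np.abs(np.asarray(p) - np.asarray(q)).sum()

p = np.array([1.0, 0.0, 0.0])       # all mass in bin 1
q_near = np.array([0.0, 1.0, 0.0])  # all mass in the neighboring bin
q_far = np.array([0.0, 0.0, 1.0])   # all mass two bins away

d_near = emd_on_line(p, q_near)     # 1.0
d_far = emd_on_line(p, q_far)       # 2.0
tv_near = total_variation(p, q_near)  # 1.0
tv_far = total_variation(p, q_far)    # 1.0
```

The Earth Mover's distance doubles when the mass has to travel twice as far, while total variation reports the same value in both cases. That difference is the role the ground metric plays.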
Making it work on discrete data using graphs
For discrete outcomes (like categories 1, 2, 3, …), the authors represent the ground relationships with a graph:
- Nodes = outcomes (categories).
- Edges = “neighbors” (which outcomes are similar or can move into each other).
- Edge weights = how easy or hard it is to move probability mass between two outcomes (shorter edges mean easier moves).
They use a dynamic view of optimal transport, which imagines probability “flowing” along the graph’s edges over time, like traffic on roads. The “energy” of a move depends on:
- How fast probability flows along each edge.
- How much probability is present at the connected nodes.
- How costly the edge is (from the ground metric).
From this, they build a geometry (a way to measure lengths and angles) on the space of histograms. A key matrix, built from the graph and the current probabilities, plays the role of a “Laplacian” and captures how flows cost energy.
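A minimal sketch of such a matrix, assuming the construction L(p) = DᵀΛ(p)D with the arithmetic-mean edge weight (p_i + p_j)/2 that this page mentions elsewhere. The function name, the sign convention for D, and the 3-node line graph with unit weights ω are illustrative assumptions, not taken from the paper's text.

```python
import numpy as np

def weighted_laplacian(p, edges, omega):
    """Probability-weighted graph Laplacian L(p) = D^T Lambda(p) D."""
    p = np.asarray(p, dtype=float)
    D = np.zeros((len(edges), p.size))   # discrete gradient: one row per edge
    lam = np.zeros(len(edges))           # diagonal of Lambda(p)
    for e, ((i, j), w) in enumerate(zip(edges, omega)):
        D[e, i] = np.sqrt(w)
        D[e, j] = -np.sqrt(w)
        lam[e] = 0.5 * (p[i] + p[j])     # mass available on edge (i, j)
    return D.T @ (lam[:, None] * D)

edges = [(0, 1), (1, 2)]       # line graph 1-2-3, written with 0-based indices
omega = [1.0, 1.0]             # assumed unit edge weights
p = np.array([0.2, 0.5, 0.3])  # a point in the interior of the simplex
L = weighted_laplacian(p, edges, omega)
```

For a potential Φ on the nodes, ΦᵀL(p)Φ sums ω_ij (Φ_i − Φ_j)² (p_i + p_j)/2 over the edges, which is the flow energy described above. Each row of L(p) sums to zero, so constant potentials cost nothing.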
Important detail: in discrete spaces, the dynamic view of optimal transport is not the same as the classic static, linear programming view. Only the dynamic view here gives a smooth, Riemannian-like geometry suitable for defining natural gradients.
Bringing the geometry to model parameters
A statistical model has parameters θ (the knobs we tune) and outputs a distribution p(θ). The authors “pull back” the geometry from the space of distributions onto the parameter space:
- They combine the graph-based Laplacian with the model’s Jacobian (how p changes with θ).
- This gives a metric tensor G(θ), which tells you how to measure steps in parameter space in a way that respects the ground metric on outcomes.
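The pullback can be sketched for a small model. The code below assumes the metric takes the form G(θ) = Jᵀ L(p)† J, with † the pseudo-inverse, because the metric measures changes in p while L(p) maps potentials to such changes. The softmax model with its last logit pinned to 0, the 3-node line graph, and all function names are illustrative assumptions.

```python
import numpy as np

def softmax_model(theta):
    """Three outcomes, two free parameters; the last logit is fixed at 0."""
    logits = np.append(theta, 0.0)
    z = np.exp(logits - logits.max())
    return z / z.sum()

def model_jacobian(theta):
    """J[i, k] = d p_i / d theta_k for the softmax model above."""
    p = softmax_model(theta)
    return (np.diag(p) - np.outer(p, p))[:, :-1]

def weighted_laplacian(p, edges, omega):
    L = np.zeros((p.size, p.size))
    for (i, j), w in zip(edges, omega):
        c = w * 0.5 * (p[i] + p[j])
        L[i, i] += c
        L[j, j] += c
        L[i, j] -= c
        L[j, i] -= c
    return L

def wasserstein_metric(theta, edges, omega):
    p = softmax_model(theta)
    J = model_jacobian(theta)
    L = weighted_laplacian(p, edges, omega)
    n = p.size
    # L is singular (constants lie in its null space). The columns of J sum to
    # zero, so solving with L + (1/n) 1 1^T applies the pseudo-inverse to them.
    W = np.linalg.solve(L + np.ones((n, n)) / n, J)
    return J.T @ W

def fisher_metric(theta):
    """Fisher information of the same model, for comparison."""
    p = softmax_model(theta)
    J = model_jacobian(theta)
    return J.T @ (J / p[:, None])

edges = [(0, 1), (1, 2)]   # assumed line graph 1-2-3, 0-based
omega = [1.0, 1.0]         # assumed unit edge weights
theta = np.array([0.3, -0.2])
G_wasserstein = wasserstein_metric(theta, edges, omega)
G_fisher = fisher_metric(theta)
```

Both matrices are symmetric positive definite for this model, but only the first changes when the graph or its weights change.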
With G(θ) in hand, the Wasserstein natural gradient is:
- Update rule: θ̇ = −G(θ)⁻¹ ∇θF(p(θ))
- In words: move in the steepest descent direction, but scaled by the inverse of the new geometry. This is the same natural-gradient recipe you may have seen, but now the geometry is derived from optimal transport on a graph.
They also describe two practical update schemes:
- Forward (explicit) updates: simple, standard “natural gradient” steps with a step size.
- Backward (implicit) updates, also known as the JKO scheme: more stable, often used in Wasserstein gradient flows.
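A minimal sketch of the forward (explicit) scheme, assuming the metric G(θ) = Jᵀ L(p)† J and using the loss F(p) = KL(q ‖ p), which corresponds to maximum likelihood when q is the empirical distribution. The model, graph, step size, and function names are illustrative assumptions. The backward (JKO) scheme is not sketched because it needs the distance on parameter space.

```python
import numpy as np

EDGES = [(0, 1), (1, 2)]   # assumed line graph 1-2-3, 0-based
OMEGA = [1.0, 1.0]         # assumed unit edge weights

def model(theta):
    """Softmax over 3 outcomes with the last logit pinned to 0."""
    logits = np.append(theta, 0.0)
    z = np.exp(logits - logits.max())
    return z / z.sum()

def laplacian(p):
    L = np.zeros((p.size, p.size))
    for (i, j), w in zip(EDGES, OMEGA):
        c = w * 0.5 * (p[i] + p[j])
        L[i, i] += c
        L[j, j] += c
        L[i, j] -= c
        L[j, i] -= c
    return L

def metric_and_jacobian(theta):
    p = model(theta)
    J = (np.diag(p) - np.outer(p, p))[:, :-1]
    shift = np.ones((p.size, p.size)) / p.size   # makes the singular L invertible
    G = J.T @ np.linalg.solve(laplacian(p) + shift, J)
    return G, J

def kl(q, p):
    return float(np.sum(q * np.log(q / p)))

def natural_gradient_descent(q, theta0, eta=0.1, steps=500):
    theta = np.array(theta0, dtype=float)
    losses = [kl(q, model(theta))]
    for _ in range(steps):
        p = model(theta)
        G, J = metric_and_jacobian(theta)
        grad = J.T @ (-q / p)                           # chain rule: J^T grad_p F
        theta = theta - eta * np.linalg.solve(G, grad)  # forward Euler step
        losses.append(kl(q, model(theta)))
    return theta, losses

q = np.array([0.2, 0.5, 0.3])   # target (for example, empirical) distribution
theta, losses = natural_gradient_descent(q, theta0=[0.0, 0.0])
p_final = model(theta)
```

The step solves a linear system with G(θ) instead of forming its inverse. The step size `eta` is a tuning choice here, not a value from the paper.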
What they found
- A new natural gradient for discrete probability models: They rigorously construct a metric on the probability simplex (all histograms whose probabilities sum to 1) using optimal transport on a graph, then pull it back to model parameters to get a Wasserstein natural gradient.
- It depends on a ground metric: Unlike the Fisher-Rao metric, this geometry respects similarities between categories. If two outcomes are “close” on the graph, moving probability between them is cheaper; the gradient method will naturally prefer such moves.
- Explicit formulas: They give clear formulas for the metric G(θ) and the gradient flow. The key building blocks are the weighted graph Laplacian, which depends on the current probability distribution, and the model’s Jacobian (how outputs change with parameters).
- Displacement convexity on parameter space: They extend a concept from optimal transport called displacement convexity (roughly, “convex along the special shortest paths defined by OT”) to the parameter space. This helps reason about when a loss function is nicely shaped and easier to optimize.
- Illustrative examples: On a simple 3-node line graph (1–2–3), geodesics (“shortest paths” in probability space) visibly curve toward the middle node. This shows how the ground metric changes the geometry: outcome 2 behaves as a hub, and the shortest path reflects that. They also show how this works for exponential-family models and how to apply it to maximum likelihood estimation.
Why this matters
- Better inductive bias: Many tasks have outcomes that are not all equally different. For example, in handwritten digits, 3 is more like 8 than like 0; in language, some words are closer in meaning. This method lets you bake that knowledge into the optimization through a graph-based ground metric.
- More meaningful steps: The Wasserstein natural gradient respects the actual “cost to move probability” between categories. This can lead to smoother, more sensible updates and potentially faster or more reliable training.
- Works directly for discrete data: The paper tailors the continuous ideas of optimal transport to discrete settings, which are common in practice.
Potential impact and applications
This approach can help whenever outcomes have a natural notion of distance or similarity. Examples include:
- Text and LLMs (words connected by a semantic or co-occurrence graph)
- Recommender systems (items connected by similarity)
- Image or grid data (pixels or regions connected by spatial adjacency)
- Structured prediction and classification tasks (labels connected by a taxonomy)
In practice, the method provides:
- A principled way to define a natural gradient that “knows” the structure of your sample space.
- Two practical optimization schemes (forward natural gradient and the JKO scheme) for training.
Overall, the paper offers a clear bridge between optimal transport and information geometry in the discrete setting, giving machine learning practitioners a new tool to incorporate prior knowledge about how outcomes relate.
Knowledge Gaps
Knowledge gaps, limitations, and open questions
The paper introduces a Wasserstein-based natural gradient on parametric statistical models over discrete sample spaces; the following unresolved issues and open directions remain:
- Scope limitation to discrete sample spaces: no development of the pullback Wasserstein metric and corresponding gradient flow for parametric models on continuous sample spaces, where computing the metric would involve PDE operators and functional-analytic issues.
- Dependence on chosen ground metric/graph: no guidance on how to select, validate, or learn the graph topology and edge weights ω (or distances dG), despite these choices critically shaping the geometry and optimization behavior.
- Alternative edge-weight conventions: the construction uses the arithmetic edge weight (p_i + p_j)/2; the consequences of using alternatives (e.g., logarithmic mean as in Maas-type discrete OT) for convexity, curvature, and algorithmic performance are not analyzed.
- Graph connectivity assumptions: the framework implicitly requires a connected graph (to avoid multiple zero eigenvalues of L(p)); behavior and remedies for disconnected or nearly disconnected graphs are not discussed.
- Sensitivity and scaling of edge weights: the impact of global rescaling or local reweighting of ω on convergence speed, conditioning, and the optimization path is not quantified or normalized (e.g., whether rescaling is equivalent to a time change).
- Boundary behavior: analysis is restricted to the interior of the simplex (p_i > 0); numerical and theoretical issues when probabilities approach zero (ill-conditioned L(p), numerical instability) are not addressed, nor are remedies (e.g., entropic barriers or interior-point safeguards).
- Parameterization rank and identifiability: the method assumes full-rank J_θp(θ); no treatment is given for near-singular Jacobians, overparameterized models, or parameter identifiability issues and their effects on the metric and optimization.
- Metric invariance beyond a fixed model map: while the metric is defined as a pullback via p(θ), explicit discussion of invariance under reparameterizations that represent the same statistical model (and practical implications for implementation) is missing.
- Computational scalability: each step requires forming and (pseudo-)inverting L(p(θ)) and the d×d metric G(θ); complexity and memory costs for large state spaces (large n) and strategies for scalable approximation (e.g., sparsity, iterative solvers, low-rank structure, preconditioning) are not provided.
- Efficient geodesic distance computation: the JKO scheme needs Dist(·,·) on parameter space; computational methods or approximations for this distance (or for solving the geodesic boundary value problem on (Θ, g)) are not given.
- Convergence guarantees: there are no theoretical convergence rates or step-size conditions for the forward-Euler (natural gradient) method or for the JKO scheme on (Θ, g), even under smoothness or (displacement) convexity assumptions.
- Geodesic convexity conditions in practice: while a general displacement convexity condition is stated, concrete verification for common ML objectives (e.g., cross-entropy on softmax models) and typical graphs is absent.
- Curvature and geometric properties: sectional/Ricci curvature, completeness, and existence/uniqueness of geodesics on the pullback manifold (Θ, g) are not characterized, limiting theoretical understanding of optimization dynamics.
- Comparison with Fisher-Rao natural gradient: beyond conceptual differences, no systematic analysis (theoretical or empirical) clarifies when Wasserstein natural gradient outperforms Fisher natural gradient or standard optimizers, and why.
- Robustness to misspecified ground metrics: the effect of incorrect or noisy graph/metric choices on optimization outcomes and statistical performance is not quantified.
- Joint learning of the ground metric: there is no formulation for learning ω (or dG) together with model parameters, nor constraints ensuring well-posedness (e.g., connectivity, positive weights) and regularization to avoid degenerate geometries.
- Stochastic optimization: the paper does not address mini-batch/online gradients, unbiased estimators of ∇_pF and their variance, or how stochasticity interacts with the parameter-dependent metric G(θ).
- Numerical stability: no discussion of regularization strategies (e.g., damping, Tikhonov on L(p)† or G(θ)⁻¹), line-search, or trust-region methods adapted to the Wasserstein metric to ensure stable updates.
- Handling non-smooth objectives: although the JKO scheme is suggested for non-smooth F, practical algorithms for the proximal step (which requires Dist) and convergence guarantees in non-smooth settings are not provided.
- Large-scale applications: empirical validation is limited to toy examples; there are no experiments on realistic datasets, no runtime benchmarks, and no ablation studies on graph choices or comparison to baselines (Fisher NG, Adam, etc.).
- Extensions beyond simple parametric families: although the framework is said to apply to arbitrary models, the paper does not demonstrate implementation for high-dimensional models (e.g., deep networks) where n and d are large and Jacobians are costly.
- Relation to static OT on discrete spaces: the practical and theoretical implications of the inequivalence between dynamic (Benamou–Brenier-type) and static (LP-based) OT in discrete settings for statistical tasks remain unexplored (e.g., which is preferable for learning, and under what conditions).
Practical Applications
Overview
The paper introduces a natural gradient method for statistical models on discrete sample spaces that is induced by the dynamic (Benamou–Brenier) L2-Wasserstein geometry on a weighted graph over the sample space. By pulling back this geometry to parameter space via the model map θ ↦ p(θ), it yields:
- A Riemannian metric on parameters, G(θ) = J(θ)ᵀ L(p(θ))† J(θ), where L(p) is a probability-weighted graph Laplacian, † denotes the pseudo-inverse, and J is the Jacobian of p(θ).
- A Wasserstein natural gradient flow: θ̇ = −G(θ)⁻¹ ∇θF(p(θ)), and discrete-time algorithms (forward “natural gradient” and backward “JKO” proximal scheme).
- A framework to encode a ground metric (similarity) among discrete outcomes directly into optimization over model parameters.
- An extension of displacement convexity to parameter space for analysis and guarantees.
Below are practical applications derived from the findings, organized by immediacy, with sectors, potential tools/workflows, and feasibility notes.
Immediate Applications
These can be deployed now with modest engineering effort using existing ML/optimization stacks, provided a sensible ground metric (graph) on discrete outcomes is available.
- Wasserstein natural gradient optimizer for categorical models (software, ML)
- Use case: Replace Euclidean or Fisher natural gradient in training models with discrete outputs (e.g., multinomial logistic/softmax regression, LLMs over vocabularies, topic models) using the paper’s θ-update: θ ← θ − η G(θ)⁻¹ ∇θF(p(θ)).
- Sectors: Software/ML platforms, NLP, recommender systems.
- Tools/workflow:
- Define label/item similarity graph G = (V, E, ω); compute sparse D, Λ(p), L(p) = DᵀΛ(p)D.
- Compute J(θ) and ∇pF(p(θ)); form G(θ) = JᵀL†J, applying L(p)† through a linear solve rather than an explicit inverse, and apply G(θ)⁻¹ via iterative SPD solvers (CG) or low-rank structure.
- Integrate as an optimizer in PyTorch/TensorFlow; expose ω and graph construction as a module.
- Assumptions/dependencies:
- Meaningful ground metric on labels/items; p(θ) strictly positive with full-rank J(θ).
- Scalability hinges on sparse graphs and efficient linear solvers.
- Maximum likelihood estimation with Wasserstein natural gradient (industry/academia)
- Use case: Fit exponential-family or multinomial models with improved preconditioning that respects label similarity; useful when labels are structured (e.g., hierarchies/taxonomies).
- Sectors: E-commerce (product taxonomy), healthcare (ICD/ATC code hierarchies), genomics (GO terms).
- Tools/workflow: Drop-in replacement of optimizer in MLE pipelines (scikit-learn, statsmodels) with WNG; optionally tune the graph from metadata or embeddings.
- Assumptions/dependencies: Availability of a label graph; numerical stability for large label sets.
- Label-aware training for imbalanced/rare classes (healthcare, education, NLP)
- Use case: When errors between “nearby” classes are less severe, WNG updates steer probability mass across adjacent labels in the graph, improving learning for sparse/rare labels.
- Sectors: Medical coding (ICD prediction), skill tagging (educational taxonomies), phoneme/character recognition.
- Tools/workflow: Construct ω from domain ontology or confusion statistics; use WNG-based optimizer for cross-entropy training.
- Assumptions/dependencies: Correctly specified similarity graph; careful step-size tuning.
- Policy-gradient updates with action similarity (robotics, RL)
- Use case: Policy optimization over discrete actions where actions have known similarities (e.g., adjacent headings, motor primitives). WNG provides a geometry-aware preconditioner.
- Sectors: Robotics, game AI, operations research.
- Tools/workflow: Define action graph; replace vanilla policy gradient’s preconditioning with G(θ)⁻¹ for the policy’s categorical outputs.
- Assumptions/dependencies: Action similarity must be meaningfully defined; overhead must be acceptable in on-policy updates.
- Variational inference over discrete latent distributions (academia/software)
- Use case: Optimize variational parameters of categorical/mixture distributions using WNG to accelerate and stabilize convergence (e.g., topic-word distributions).
- Sectors: Bayesian ML, probabilistic programming.
- Tools/workflow: Integrate WNG into VI solvers for categorical factors; use word/item similarity graphs from embeddings as ω.
- Assumptions/dependencies: Positivity constraints and full rank; scalability with large vocabularies requires sparse graphs.
- Distribution shift and histogram smoothing over discrete bins (policy, analytics)
- Use case: For small to medium n, compute Wasserstein geodesics/flows to visualize/evaluate shifts among categorical histograms where bins have adjacency (e.g., geography, age groups).
- Sectors: Public policy, epidemiology, social science, marketing analytics.
- Tools/workflow: Build adjacency graph (e.g., spatial neighbors); compute geodesic paths or W-gradient flows to interpolate/denoise counts.
- Assumptions/dependencies: n moderate; continuous-time geodesic computation feasible; sensitivity to chosen adjacency.
- Recommender systems with item-taxonomy-aware probability updates (industry)
- Use case: Next-item/category prediction where items have a taxonomy/similarity graph; WNG encourages smooth mass movement across similar items, improving user-perceived relevance.
- Sectors: E-commerce, media platforms.
- Tools/workflow: Build ω from product hierarchy or embedding k-NN graph; deploy WNG for softmax head optimization.
- Assumptions/dependencies: Graph quality determines gains; manage computational cost for large catalogs.
- Regularization via JKO-like proximal step for non-smooth objectives (software/optimization)
- Use case: When the loss is non-smooth or stiff, use the backward (JKO) step with Dist(θ, θᵏ) as a geometry-aware proximal term to stabilize updates.
- Sectors: Optimization libraries, applied ML pipelines.
- Tools/workflow: Implement Dist via energy functional; solve the proximal step with inner optimization; use for robust training or denoising.
- Assumptions/dependencies: Computing Dist is more expensive than forward updates; suited to smaller models or offline training.
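The tools/workflow steps listed for the categorical-model optimizer can be sketched without ever forming L(p) or G(θ) as dense matrices. The code below is a hedged sketch, assuming G(θ) = Jᵀ L(p)† J, a softmax head with its last logit pinned to 0, a path graph with unit weights, and the loss KL(q ‖ p). The hand-written conjugate gradient routine and all names are illustrative; a production version would use a library solver and sparse storage.

```python
import numpy as np

def conjugate_gradient(matvec, b, rtol=1e-12, max_iter=200):
    """Solve A x = b for symmetric positive definite A, given only v -> A v."""
    x = np.zeros_like(b)
    r = b.copy()
    d = r.copy()
    rs = r @ r
    b_norm = np.sqrt(rs)
    for _ in range(max_iter):
        if np.sqrt(rs) <= rtol * b_norm:
            break
        Ad = matvec(d)
        alpha = rs / (d @ Ad)
        x += alpha * d
        r -= alpha * Ad
        rs_new = r @ r
        d = r + (rs_new / rs) * d
        rs = rs_new
    return x

# Hypothetical setup: path graph on n labels with unit weights omega.
n = 8
src = np.arange(n - 1)      # edge e joins labels src[e] and dst[e]
dst = src + 1
sqrt_w = np.ones(n - 1)     # sqrt(omega_e) for each edge

def model(theta):
    """Softmax head with the last logit pinned to 0 (full-rank Jacobian)."""
    logits = np.append(theta, 0.0)
    z = np.exp(logits - logits.max())
    return z / z.sum()

def apply_L(p, u):
    """u -> D^T Lambda(p) D u from edge arrays, without a dense matrix."""
    flux = 0.5 * (p[src] + p[dst]) * sqrt_w * (u[src] - u[dst])
    out = np.zeros_like(u)
    np.add.at(out, src, sqrt_w * flux)
    np.add.at(out, dst, -sqrt_w * flux)
    return out

def apply_L_pinv(p, u):
    """u -> L(p)^+ u for zero-sum u, via CG on the shifted operator L + (1/n) 1 1^T."""
    return conjugate_gradient(lambda v: apply_L(p, v) + v.mean(), u)

def apply_J(p, v):
    v_full = np.append(v, 0.0)
    return p * (v_full - p @ v_full)

def apply_JT(p, w):
    return (p * (w - p @ w))[:-1]

def apply_G(p, v):
    return apply_JT(p, apply_L_pinv(p, apply_J(p, v)))

def wng_direction(theta, q):
    """Solve G(theta) x = grad_theta KL(q || p(theta)) with matrix-free CG."""
    p = model(theta)
    grad = apply_JT(p, -q / p)
    return conjugate_gradient(lambda v: apply_G(p, v), grad, rtol=1e-10, max_iter=50)

def kl(q, p):
    return float(np.sum(q * np.log(q / p)))

q = np.array([1, 2, 3, 4, 4, 3, 2, 1], dtype=float)
q /= q.sum()
theta = np.zeros(n - 1)
losses = [kl(q, model(theta))]
for _ in range(200):
    theta = theta - 0.2 * wng_direction(theta, q)   # eta = 0.2 is a tuning choice
    losses.append(kl(q, model(theta)))
```

Each application of G(θ) needs an inner solve with L(p), so the cost per step is dominated by Laplacian solves. This is why sparse graphs and fast Laplacian solvers matter for larger label sets.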
Long-Term Applications
These require further research, scaling advances, or methodological extensions before broad deployment.
- Scaling to very large discrete spaces and deep networks (software, NLP, recommender systems)
- Vision: Efficient WNG for vocabularies/items with 10⁵–10⁶ nodes using structured/sparse graphs, low-rank Jacobian surrogates, and fast Laplacian solvers (e.g., multigrid, graph sparsifiers).
- Potential tools: Specialized CUDA kernels for L(p)·v and G(θ)·v products; block-diagonal or Krylov preconditioning; quasi-Newton approximations respecting W-geometry.
- Dependencies: Research into approximate inverses for G(θ); variance-controlled stochastic estimates of JᵀL†J.
- Joint learning of the ground metric and model parameters (ML research, fairness)
- Vision: Meta-learning ω (graph weights) alongside θ to discover outcome similarities that optimize task performance while encoding constraints (e.g., fairness or domain priors).
- Potential products: Auto-graph construction modules; bilevel optimization frameworks.
- Dependencies: Identifiability, regularization of ω to prevent degenerate solutions; interpretability and bias audits.
- Continuous-state and hybrid extensions (academia, robotics, control)
- Vision: Extend pullback-Wasserstein geometry to continuous or mixed spaces with tractable solvers (e.g., Gaussian families, normalizing flows), enabling geometry-aware learning in continuous control and density estimation.
- Potential tools: PDE-inspired solvers, operator learning; coupling with score-based models.
- Dependencies: Robust numerical methods for continuous L(ρ) and its pullbacks; stability analysis.
- Certified optimization via displacement convexity in parameter space (theory to practice)
- Vision: Use parameter-space displacement convexity to design objectives/priors with global convergence guarantees and adaptive step sizes.
- Potential products: Optimizers with curvature-adaptive schedules based on Γ-calculus diagnostics; certified training regimes for certain loss families.
- Dependencies: Practical estimators of curvature terms (Γ operators, second fundamental form) at scale.
- Probabilistic forecasting and portfolio allocation with asset similarity (finance)
- Vision: Optimize probability allocations over assets/sectors with a similarity graph derived from correlations or fundamentals, reducing churn toward dissimilar assets.
- Potential tools: WNG-based forecast calibration; proximal updates with Dist-based penalties for turnover control.
- Dependencies: Stable, robust asset-similarity metrics; handling time-varying ω.
- Privacy-preserving histogram mechanisms using W-geometry (policy, privacy tech)
- Vision: Design differentially private mechanisms over histograms that respect ground-metric utility, leveraging W-proximal steps to denoise while preserving meaningful structure.
- Potential tools: JKO-based post-processing; W-aware DP budget allocation.
- Dependencies: Formal DP analyses under W-based post-processing; sensitivity to graph choice.
- Belief-space planning and filtering in discrete POMDPs (robotics)
- Vision: Use WNG for belief updates over discrete state spaces with known adjacency, yielding smoother and more physically plausible belief evolution.
- Potential tools: Geometry-aware filters; planning with Dist-regularized objectives.
- Dependencies: Integration with POMDP solvers; computational overhead in real time.
- Image/color/texture morphing via discrete dynamic OT geodesics (graphics/vision)
- Vision: Use discrete W-geodesics on color/texture histograms for perceptually smooth transitions and denoising.
- Potential tools: Plugins for DCC tools; histogram interpolation libraries.
- Dependencies: Efficient small-n implementations already feasible; extension to high-dimensional descriptors needs work.
Key Assumptions and Dependencies (Cross-cutting)
- Ground metric availability and quality: Success hinges on a well-chosen graph over outcomes/actions/items; misspecification can hurt performance or encode bias.
- Model smoothness and positivity: p(θ) must be smooth with strictly positive probabilities and full-rank Jacobian in the region of interest.
- Computational feasibility: L(p) and G(θ) must be applied/inverted efficiently; sparsity and iterative solvers are essential for large problems.
- Discrete dynamic OT vs static OT: The geometry arises from the dynamic (Maas/Chow) formulation on graphs, not the LP-based Kantorovich distance; interpretations and behavior differ.
- Stability and step sizing: Forward (natural gradient) updates require careful step-size control; JKO (proximal) updates are more stable but more computationally intensive.
Whenever a meaningful similarity structure exists among outcomes, practitioners can integrate the Wasserstein natural gradient into existing optimization workflows and encode that prior directly into parameter updates. In small to medium discrete spaces this often takes only modest engineering effort; larger systems depend on the scalable variants outlined under Long-Term Applications.