Optimal Transport for Machine Learners
Abstract: Optimal Transport is a foundational mathematical theory that connects optimization, partial differential equations, and probability. It offers a powerful framework for comparing probability distributions and has recently become an important tool in machine learning, especially for designing and evaluating generative models. These course notes cover the fundamental mathematical aspects of OT, including the Monge and Kantorovich formulations, Brenier's theorem, the dual and dynamic formulations, the Bures metric on Gaussian distributions, and gradient flows. They also introduce numerical methods such as linear programming, semi-discrete solvers, and entropic regularization. Applications in machine learning include training neural networks via gradient flows, token dynamics in transformers, and the structure of GANs and diffusion models. The notes focus primarily on mathematical content rather than deep learning techniques.
Summary
- The notes present the Monge and Kantorovich formulations of optimal transport, together with duality, Brenier's theorem, and the dynamic formulation, as a unified framework for comparing probability distributions.
- They cover scalable computational methods, including semi-discrete solvers and entropic regularization via the Sinkhorn algorithm, for large-scale machine learning problems.
- They show how optimal transport supplies geometrically meaningful metrics between distributions, with applications to generative modeling, domain adaptation, and the analysis of training dynamics.
Optimal Transport (OT) is a mathematical framework for comparing probability distributions (or, more generally, measures) by finding the most efficient way to move mass from one distribution to another. It originates in Monge's problem, posed in 1781, and was reformulated by Kantorovich in the 1940s as a linear program. OT has recently gained significant traction in machine learning because it provides a geometrically meaningful distance between distributions. This is particularly useful for structured data such as images, text, or point clouds, where traditional measures like the Kullback-Leibler divergence fall short: KL is infinite when the supports do not overlap. These notes (arXiv:2505.06589) cover the fundamental mathematical concepts, computational methods, and applications of OT relevant to machine learning practitioners.
Mathematical Foundations
The core idea of OT is finding a "transport plan" that minimizes the total cost of moving mass from a source distribution α to a target distribution β, given a cost function c(x,y) between points x in the source space X and y in the target space Y.
- Monge Problem: This is the original formulation. It seeks a map $T:X\to Y$ satisfying $T_\sharp\alpha=\beta$ (that is, $T$ pushes $\alpha$ forward to $\beta$) that minimizes $\int_X c(x,T(x))\,d\alpha(x)$.
- Discrete Case: For two sets of points $\{x_i\}_{i=1}^n$ and $\{y_j\}_{j=1}^m$, the simplest case matches $n=m$ points with uniform mass $1/n$. This is the optimal assignment problem, whose naive solution enumerates all $n!$ permutations. For 1D points and costs $c(x,y)=h(|x-y|)$ with $h$ convex, the optimal map pairs the sorted points in order, giving an $O(n\log n)$ algorithm. A practical application is grayscale histogram equalization in image processing, where the CDFs of the source and target image histograms are matched. However, this 1D sorting trick does not extend to higher dimensions.
- Continuous Case: The push-forward constraint $T_\sharp\alpha=\beta$ means $\int_Y h(y)\,d\beta(y)=\int_X h(T(x))\,d\alpha(x)$ for every test function $h$. A major limitation is that an admissible map $T$ may not exist at all: if $\alpha$ is a single Dirac mass and $\beta$ is not, no map can split that mass across several locations.
- Brenier's Theorem: For the squared Euclidean cost $c(x,y)=\|x-y\|^2$ on $\mathbb{R}^d$, if $\alpha$ has a density, there exists a unique optimal map $T$, and it is the gradient of a convex function, $T=\nabla\phi$. This gives the optimal map valuable structure, although computing it can still be challenging. It also leads to the Monge-Ampère equation $\det(\partial^2\phi(x))\,\rho_\beta(\nabla\phi(x))=\rho_\alpha(x)$, which relates the densities $\rho_\alpha$, $\rho_\beta$ and the potential $\phi$.
- Kantorovich Relaxation: To address the limitations of Monge's formulation, Kantorovich introduced a relaxed problem based on "couplings" or "transport plans." A coupling $\pi$ is a joint probability measure on $X\times Y$ with marginals $\pi_1=\alpha$ and $\pi_2=\beta$. The problem is to minimize $\int_{X\times Y}c(x,y)\,d\pi(x,y)$ over the set of admissible couplings $\Pi(\alpha,\beta)$.
- Discrete Case: For discrete measures $\alpha=\sum_i a_i\delta_{x_i}$ and $\beta=\sum_j b_j\delta_{y_j}$, a coupling is a matrix $P\in\mathbb{R}_+^{n\times m}$ whose entry $P_{ij}$ is the mass moved from $x_i$ to $y_j$. The marginal constraints are $P\mathbb{1}_m=a$ and $P^\top\mathbb{1}_n=b$, and the objective is $\sum_{i,j}C_{ij}P_{ij}$ with $C_{ij}=c(x_i,y_j)$. This is a linear program. Efficient algorithms include the Hungarian algorithm ($O(n^3)$ for the assignment problem), the auction algorithm, and the network simplex ($O(n^3\log n)$, or faster variants). For the assignment problem ($n=m$, uniform weights), the Birkhoff-von Neumann theorem states that the extreme points of the feasible set (the bistochastic matrices, up to the factor $1/n$) are the permutation matrices. Since a linear program attains its minimum at an extreme point, some optimal coupling is a permutation, that is, a Monge map.
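- Continuous Case: The Kantorovich problem is an infinite-dimensional linear program. Its set of couplings $\Pi(\alpha,\beta)$ is never empty, since it contains the product measure $\alpha\otimes\beta$, and under mild assumptions (for instance, a continuous cost on compact spaces) a minimizer always exists. The problem can be interpreted probabilistically as minimizing $\mathbb{E}[c(X,Y)]$ over pairs of random variables $(X,Y)$ with marginal distributions $\alpha$ and $\beta$.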
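The 1D sorting claim for the optimal assignment problem can be checked numerically. The sketch below is a minimal illustration (the helper names `sorted_assignment`, `brute_force_assignment` and `assignment_cost` are hypothetical, not from the notes): it compares the sorting solution with an exhaustive search over all $n!$ permutations, which is only feasible for tiny $n$, using the convex cost $h(t)=t^2$.

```python
import itertools
import numpy as np

def sorted_assignment(x, y):
    """Optimal 1D matching for a convex cost h(|x - y|): the k-th smallest x goes
    to the k-th smallest y. Returns sigma such that x[i] is matched to y[sigma[i]]."""
    ix, iy = np.argsort(x), np.argsort(y)
    sigma = np.empty(len(x), dtype=int)
    sigma[ix] = iy
    return sigma

def assignment_cost(x, y, sigma, cost):
    return sum(cost(x[i], y[sigma[i]]) for i in range(len(x)))

def brute_force_assignment(x, y, cost):
    """Exhaustive search over all n! permutations (tiny n only)."""
    best_perm, best_cost = None, np.inf
    for perm in itertools.permutations(range(len(x))):
        total = assignment_cost(x, y, perm, cost)
        if total < best_cost:
            best_perm, best_cost = perm, total
    return np.array(best_perm), best_cost

rng = np.random.default_rng(0)
x = rng.normal(size=6)
y = rng.normal(loc=1.0, size=6)
sq_cost = lambda s, t: (s - t) ** 2      # h(t) = t^2 is convex

sigma = sorted_assignment(x, y)
_, best = brute_force_assignment(x, y, sq_cost)
print(assignment_cost(x, y, sigma, sq_cost), best)   # the two costs coincide
```

Sorting costs $O(n\log n)$, while the brute-force loop visits $6!=720$ permutations here and becomes hopeless beyond a dozen points.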
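For Gaussian measures on the real line, Brenier's map has a closed form, which makes the Monge-Ampère equation easy to verify. Assuming $\alpha=\mathcal N(m_\alpha,s_\alpha^2)$ and $\beta=\mathcal N(m_\beta,s_\beta^2)$ (the parameter values below are illustrative), the increasing affine map $T(x)=m_\beta+\frac{s_\beta}{s_\alpha}(x-m_\alpha)$ is the derivative of a convex quadratic $\phi$, and in 1D the equation reduces to $\phi''(x)\,\rho_\beta(T(x))=\rho_\alpha(x)$. This sketch checks the equation on a grid and the push-forward on samples.

```python
import numpy as np

def gaussian_pdf(x, m, s):
    return np.exp(-0.5 * ((x - m) / s) ** 2) / (s * np.sqrt(2.0 * np.pi))

# Illustrative parameters: alpha = N(m_a, s_a^2), beta = N(m_b, s_b^2).
m_a, s_a = 0.0, 1.0
m_b, s_b = 2.0, 0.5

def T(x):
    # T = phi' for the convex quadratic phi(x) = m_b*x + (s_b/s_a)*(x - m_a)**2/2,
    # so phi''(x) = s_b/s_a > 0 plays the role of det(d^2 phi) in 1D.
    return m_b + (s_b / s_a) * (x - m_a)

x = np.linspace(-4.0, 4.0, 201)
lhs = (s_b / s_a) * gaussian_pdf(T(x), m_b, s_b)   # phi''(x) * rho_beta(T(x))
rhs = gaussian_pdf(x, m_a, s_a)                    # rho_alpha(x)
print(np.max(np.abs(lhs - rhs)))                   # round-off level

# Push-forward check: samples of alpha mapped by T should be distributed like beta.
samples = np.random.default_rng(0).normal(m_a, s_a, size=100_000)
print(T(samples).mean(), T(samples).std())         # close to m_b and s_b
```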
Metric Properties and Duality
- Wasserstein Distance: If the cost is $c(x,y)=d(x,y)^p$ for a ground distance $d$ on $X=Y$ and $p\ge1$, the minimum cost $W_p^p(\alpha,\beta)=\min_{\pi\in\Pi(\alpha,\beta)}\int d(x,y)^p\,d\pi(x,y)$ defines the $p$-Wasserstein distance $W_p$. It is a true metric: symmetric, zero only when $\alpha=\beta$, and satisfying the triangle inequality (proven using a "gluing lemma" for couplings).
- Weak Convergence: A key property of $W_p$ (for $p\ge1$) is that, on a compact space, it metrizes the weak convergence of probability measures: $\alpha_k\to\alpha$ weakly if and only if $W_p(\alpha_k,\alpha)\to0$ (on non-compact spaces, convergence of $p$-th moments is also required). This is crucial in machine learning because it allows comparing distributions whose supports need not overlap (e.g., empirical measures or Dirac masses), unlike divergences such as KL. For example, $W_p(\delta_{x_n},\delta_x)=d(x_n,x)$, so Dirac masses converge exactly when their locations do. This contrasts with the total variation (TV) distance, for which $\mathrm{TV}(\delta_{x_n},\delta_x)=2$ whenever $x_n\neq x$, however close the points are.
- Dual Problem: The Kantorovich problem has a dual formulation. For continuous measures and cost $c$, it is $\max_{f,g}\int f\,d\alpha+\int g\,d\beta$ subject to $f(x)+g(y)\le c(x,y)$ for all $x,y$. The optimal functions $(f,g)$ are called Kantorovich potentials.
- c-transforms: The constraint $f(x)+g(y)\le c(x,y)$ is equivalent to $f(x)\le c(x,y)-g(y)$ for all $x,y$, that is, $f(x)\le\inf_y\big(c(x,y)-g(y)\big)=g^{\bar c}(x)$, where $\bar c(y,x)=c(x,y)$. Similarly, $g(y)\le f^c(y)=\inf_x\big(c(x,y)-f(x)\big)$. Optimal potentials satisfy $f=g^{\bar c}$ and $g=f^c$ on the supports of $\alpha$ and $\beta$. For $c(x,y)=\|x-y\|^2/2$, the c-transform reduces to a Fenchel-Legendre transform: $\phi(x)=\|x\|^2/2-f(x)$ is convex, which connects back to Brenier's theorem through $T=\nabla\phi$.
- W1 Duality: For $c(x,y)=d(x,y)$, the dual problem simplifies significantly: the optimal potentials can be taken 1-Lipschitz with $g=-f$. The dual problem becomes $W_1(\alpha,\beta)=\sup_{f:\,\mathrm{Lip}(f)\le1}\int f\,d(\alpha-\beta)$, a result known as the Kantorovich-Rubinstein theorem. This form is very useful for theoretical analysis and for connections to other metrics.
- Integral Probability Metrics (IPMs): The W1 dual form is an instance of an IPM, $\sup_{f\in\mathcal F}\int f\,d(\alpha-\beta)$ for some function class $\mathcal F$. Other IPMs include the Maximum Mean Discrepancy (MMD), $\sup_{f\in\mathcal B}\int f\,d(\alpha-\beta)$, where $\mathcal B$ is the unit ball of the reproducing kernel Hilbert space (RKHS) of a kernel $k$. MMD is often easier to compute, especially between empirical measures, where it has a closed form in terms of kernel evaluations, but it may not capture geometric structure as well as Wasserstein distances.
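The contrast between $W_p$, total variation, and MMD on Dirac masses can be seen directly. The sketch below is illustrative and its helper names are hypothetical: it computes $W_p$ between equal-size uniform empirical measures on the real line by sorting (the 1D case of the assignment problem), TV as the $\ell^1$ distance between histograms on a merged support, and MMD with a Gaussian kernel whose bandwidth `sigma` is an assumed choice.

```python
import numpy as np

def wasserstein_1d(x, y, p=1):
    """W_p between uniform empirical measures on equal-size 1D samples:
    the optimal coupling pairs the sorted samples."""
    x, y = np.sort(np.asarray(x, float)), np.sort(np.asarray(y, float))
    return np.mean(np.abs(x - y) ** p) ** (1.0 / p)

def tv_discrete(xs, a, ys, b):
    """||alpha - beta||_1 for alpha = sum_i a_i delta_{xs_i}, beta = sum_j b_j delta_{ys_j}."""
    xs, ys = np.asarray(xs, float), np.asarray(ys, float)
    a, b = np.asarray(a, float), np.asarray(b, float)
    support = np.union1d(xs, ys)
    pa = np.array([a[xs == s].sum() for s in support])
    pb = np.array([b[ys == s].sum() for s in support])
    return np.abs(pa - pb).sum()

def mmd_gaussian(x, y, sigma=1.0):
    """MMD between uniform empirical measures with a Gaussian kernel of bandwidth sigma."""
    x, y = np.asarray(x, float)[:, None], np.asarray(y, float)[:, None]
    k = lambda u, v: np.exp(-((u - v.T) ** 2) / (2.0 * sigma ** 2))
    mmd2 = k(x, x).mean() + k(y, y).mean() - 2.0 * k(x, y).mean()
    return np.sqrt(max(mmd2, 0.0))

x = 0.0
for xn in [10.0, 1.0, 0.1, 0.01]:
    w2 = wasserstein_1d([xn], [x], p=2)           # equals |xn - x|
    tv = tv_discrete([xn], [1.0], [x], [1.0])     # 2 for any xn != x
    mmd = mmd_gaussian([xn], [x])                 # saturates near sqrt(2) when far apart
    print(f"xn={xn:>5}: W2={w2:.3f}  TV={tv:.1f}  MMD={mmd:.3f}")
```

As $x_n$ approaches $x$, $W_2$ shrinks like $|x_n-x|$ while TV stays at 2. The Gaussian MMD also goes to zero, but it saturates at $\sqrt2$ for distant points, which is one way it reflects geometry less faithfully than $W_p$.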
Computational Methods and Regularization
Solving the Kantorovich problem directly with general-purpose LP solvers can be slow for large-scale ML problems ($O(N^3)$ or worse for $N$ points). Several approaches have been developed:
- Semi-Discrete OT: If one measure is discrete (e.g., $\beta=\sum_j b_j\delta_{y_j}$) and the other ($\alpha$) has a density, the problem is semi-discrete. The dual problem becomes a finite-dimensional optimization over the dual variables $g=(g_j)_j$ attached to the discrete points $y_j$: $\max_g\int\min_j\big(c(x,y_j)-g_j\big)\,d\alpha(x)+\sum_j g_jb_j$. The gradient of this objective with respect to $g_j$ is $b_j-\alpha(L_j(g))$, where $L_j(g)=\{x:\ c(x,y_j)-g_j\le c(x,y_k)-g_k\ \text{for all } k\}$ is the Laguerre cell (a generalized Voronoi region) of $y_j$. This structure allows specialized algorithms, including stochastic gradient ascent that only requires samples from $\alpha$, which is practical for large datasets. Semi-discrete OT is closely related to optimal quantization and to k-means clustering (for the squared Euclidean cost).
- Entropic Regularization (Sinkhorn): Adding an entropy term $\epsilon\,\mathrm{KL}(\pi\,|\,\alpha\otimes\beta)$ to the Kantorovich objective gives the strictly convex problem $\min_{\pi\in\Pi(\alpha,\beta)}\int c(x,y)\,d\pi(x,y)+\epsilon\,\mathrm{KL}(\pi\,|\,\alpha\otimes\beta)$. Its unique solution $\pi_\epsilon$ has a special form, $d\pi_\epsilon(x,y)=u(x)K(x,y)v(y)\,d\alpha(x)\,d\beta(y)$ with $K(x,y)=e^{-c(x,y)/\epsilon}$. This form leads to Sinkhorn's algorithm (also known as Iterative Proportional Fitting), which alternately rescales the positive functions $u$ and $v$ so that $\pi_\epsilon$ satisfies each marginal constraint.
- Sinkhorn Algorithm: For discrete measures, this amounts to iteratively rescaling the rows and columns of $K=e^{-C/\epsilon}$:
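```python
import numpy as np

# Toy inputs (illustrative): uniform histograms a, b and a squared-distance cost.
rng = np.random.default_rng(0)
n, m, epsilon, num_iterations = 5, 4, 0.1, 500
x, y = rng.normal(size=n), rng.normal(size=m)
a, b = np.ones(n) / n, np.ones(m) / m
C = (x[:, None] - y[None, :]) ** 2
K = np.exp(-C / epsilon)                   # Gibbs kernel
u = np.ones(n) / n                         # initial guess for the scaling vectors
v = np.ones(m) / m
for _ in range(num_iterations):
    v = b / (K.T @ u)                      # rescale columns to match marginal b
    u = a / (K @ v)                        # rescale rows to match marginal a
P_epsilon = u[:, None] * K * v[None, :]    # coupling diag(u) K diag(v)
transport_cost = np.sum(C * P_epsilon)     # <C, P_epsilon>; excludes the entropy term
```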
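- Each Sinkhorn iteration costs $O(nm)$ (two matrix-vector products with $K$) and parallelizes well on GPUs, so the method is much faster than general LP solvers. The number of iterations grows as $\epsilon$ decreases, and for very small $\epsilon$ the entries of $K$ underflow, which log-domain implementations avoid.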
- As $\epsilon\to0$, $\pi_\epsilon$ converges to an optimal coupling of the original (unregularized) Kantorovich problem; in the discrete case, it is the optimal coupling with maximum entropy.
- As $\epsilon\to\infty$, $\pi_\epsilon\to\alpha\otimes\beta$ (independence), and the regularized cost converges to $\int c\,d(\alpha\otimes\beta)$. This highlights the bias of the regularized cost for large $\epsilon$.
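- Sinkhorn Divergences: To remove the entropic bias, debiased quantities such as $\bar W_\epsilon(\alpha,\beta)=\mathrm{MK}_\epsilon(\alpha,\beta)-\tfrac12\mathrm{MK}_\epsilon(\alpha,\alpha)-\tfrac12\mathrm{MK}_\epsilon(\beta,\beta)$ have been proposed, where $\mathrm{MK}_\epsilon$ is the entropic cost. Under suitable assumptions on the cost (e.g., when $e^{-c/\epsilon}$ is a positive definite universal kernel, as for $c(x,y)=\|x-y\|^2$), these Sinkhorn divergences are non-negative and vanish if and only if $\alpha=\beta$. They converge to the unregularized OT cost $W_c$ as $\epsilon\to0$ and to a kernel (MMD) distance as $\epsilon\to\infty$. They thus interpolate between OT and MMD, often offering a good trade-off between computational cost and geometric fidelity.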
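A toy version of the stochastic semi-discrete solver described above, assuming $\alpha$ uniform on $[0,1]$, three target points, and the squared cost $|x-y|^2$ (all illustrative choices). Each step draws a mini-batch from $\alpha$, assigns every sample to its Laguerre cell, and moves $g$ along the stochastic gradient $b_j-\alpha(L_j(g))$. The step size, the batch size, the averaging of the last half of the iterates, and the name `semi_discrete_sgd` are assumptions of this sketch, not prescriptions from the notes.

```python
import numpy as np

def semi_discrete_sgd(sample_alpha, y, b, n_iter=1000, batch=2000, lr=0.3, seed=0):
    """Stochastic gradient ascent on the semi-discrete dual
    max_g E_{x~alpha}[min_j (c(x, y_j) - g_j)] + sum_j g_j b_j,
    with the squared cost c(x, y) = |x - y|^2 in 1D."""
    rng = np.random.default_rng(seed)
    g = np.zeros(len(y))
    g_avg = np.zeros(len(y))
    n_avg = n_iter - n_iter // 2
    for t in range(n_iter):
        x = sample_alpha(rng, batch)
        # Laguerre-cell index of each sample: argmin_j c(x, y_j) - g_j
        j = np.argmin((x[:, None] - y[None, :]) ** 2 - g[None, :], axis=1)
        cell_mass = np.bincount(j, minlength=len(y)) / batch
        g += lr * (b - cell_mass)          # stochastic gradient b_j - alpha(L_j(g))
        if t >= n_iter // 2:
            g_avg += g / n_avg             # average the second half of the iterates
    return g_avg

y = np.array([0.1, 0.5, 0.9])
b = np.array([0.2, 0.5, 0.3])
uniform01 = lambda rng, n: rng.uniform(0.0, 1.0, size=n)   # alpha = U[0, 1]
g = semi_discrete_sgd(uniform01, y, b)

# Check: the alpha-mass of each Laguerre cell should now be close to b.
grid = np.linspace(0.0, 1.0, 100_001)
cells = np.argmin((grid[:, None] - y[None, :]) ** 2 - g[None, :], axis=1)
masses = np.bincount(cells, minlength=len(y)) / grid.size
print(masses)                              # approximately b
```

At convergence the Laguerre cells split $[0,1]$ into intervals whose lengths match $b$, and sending each cell to its point $y_j$ gives the optimal semi-discrete transport map.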
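The Sinkhorn scaling iterations can underflow for small $\epsilon$. A common remedy, swapped in here rather than taken from the notes, is to iterate on dual potentials $f,g$ in the log domain, with $P_{ij}=a_ib_j\exp\big((f_i+g_j-C_{ij})/\epsilon\big)$. The sketch below returns the coupling and the entropic cost $\mathrm{MK}_\epsilon$ relative to $a\otimes b$, then builds the debiased Sinkhorn divergence from it. It assumes 1D point clouds with squared-distance cost and a fixed iteration count instead of a convergence test; all function names are hypothetical.

```python
import numpy as np

def logsumexp(M, axis):
    m = np.max(M, axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.sum(np.exp(M - m), axis=axis, keepdims=True)), axis=axis)

def sinkhorn_log(a, b, C, eps, n_iter=2000):
    """Log-domain Sinkhorn for min_P <C, P> + eps * KL(P | a b^T).
    Returns the coupling P and the entropic cost MK_eps."""
    f, g = np.zeros(len(a)), np.zeros(len(b))
    log_a, log_b = np.log(a), np.log(b)
    for _ in range(n_iter):
        f = -eps * logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)  # rows
        g = -eps * logsumexp(log_a[:, None] + (f[:, None] - C) / eps, axis=0)  # columns
    log_ratio = (f[:, None] + g[None, :] - C) / eps     # log(P_ij / (a_i b_j))
    P = np.exp(log_a[:, None] + log_b[None, :] + log_ratio)
    cost = np.sum(C * P) + eps * np.sum(P * log_ratio)  # <C, P> + eps * KL(P | a b^T)
    return P, cost

def sq_dist(x, y):
    return (x[:, None] - y[None, :]) ** 2

def sinkhorn_divergence(x, a, y, b, eps):
    """MK_eps(a, b) - MK_eps(a, a)/2 - MK_eps(b, b)/2 for 1D point clouds."""
    _, ab = sinkhorn_log(a, b, sq_dist(x, y), eps)
    _, aa = sinkhorn_log(a, a, sq_dist(x, x), eps)
    _, bb = sinkhorn_log(b, b, sq_dist(y, y), eps)
    return ab - 0.5 * aa - 0.5 * bb

rng = np.random.default_rng(0)
x, y = rng.normal(size=50), rng.normal(loc=1.0, size=60)
a, b = np.full(50, 1 / 50), np.full(60, 1 / 60)
P, mk = sinkhorn_log(a, b, sq_dist(x, y), eps=0.5)
print(np.abs(P.sum(1) - a).max(), np.abs(P.sum(0) - b).max())  # marginal errors
print(sinkhorn_divergence(x, a, y, b, eps=0.5))   # positive for different clouds
print(sinkhorn_divergence(x, a, x, a, eps=0.5))   # zero for identical inputs
```

With a very large $\epsilon$ the returned plan is close to the independent coupling $a\,b^\top$, matching the $\epsilon\to\infty$ limit. The divergence of a measure with itself is exactly zero, whereas $\mathrm{MK}_\epsilon(\alpha,\alpha)$ is not.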
Optimal Transport in Machine Learning Applications
OT provides flexible tools for comparing and transforming probability distributions, finding applications across ML:
- Generative Models:
- GANs: The dual formulation of W1 (and of other IPMs or divergences), $\sup_{f\in\mathcal F}\int f\,d(\alpha-\beta)$, can be used as a training objective. If $\alpha$ is the model distribution and $\beta$ the data distribution, the supremum over $f$ can be approximated by a discriminator network $f_\theta$, leading to the min-max game $\min_{\text{generator}}\max_\theta\int f_\theta\,d\alpha-\int f_\theta\,d\beta$, where $\alpha$ depends on the generator's parameters. Wasserstein GANs use W1 by constraining the discriminator's Lipschitz constant.
- Flow Matching: A modern approach to generative modeling defines a time-dependent vector field $v_t$ that transports a simple distribution $\alpha_0$ (like a Gaussian) to a target $\alpha_1$. This field is derived from an interpolation $\alpha_t=(P_t)_\sharp\pi$ for some coupling $\pi$ of $\alpha_0$ and $\alpha_1$ and an interpolation map $P_t$ (for instance $P_t(x,y)=(1-t)x+ty$). The field is the conditional expectation $v_t(z)=\mathbb{E}_{u\sim\pi}\big[\partial_tP_t(u)\,\big|\,P_t(u)=z\big]$, which minimizes a simple least-squares objective. This objective can be minimized by training a neural network $v_\theta(t,z)$ on samples $(P_t(u),\partial_tP_t(u))$ drawn from the interpolation process. Sampling from $\alpha_1$ is then achieved by integrating the ODE $\dot X_t=v_\theta(t,X_t)$ starting from $X_0\sim\alpha_0$.
- Comparing Samples/Datasets: $W_p$ and Sinkhorn divergences are used as distances between empirical distributions. For datasets $\{x_i\}_{i=1}^n$ and $\{y_j\}_{j=1}^m$, the distance is computed between $\frac1n\sum_i\delta_{x_i}$ and $\frac1m\sum_j\delta_{y_j}$. This is useful for measuring differences between datasets, two-sample testing, and evaluating generative models.
- Data Structures and Learning:
- Wasserstein Barycenters: The mean of a set of distributions $\{\alpha_k\}$ in Wasserstein space is $\arg\min_\alpha\sum_kW_2^2(\alpha,\alpha_k)$ (possibly with weights). This is a convex problem in $\alpha$ and can be approximated with Sinkhorn-based methods. Barycenters are used for averaging, clustering, and dimensionality reduction (e.g., Wasserstein PCA).
- Distribution Regression: Learning models that map inputs to probability distributions (e.g., predicting a histogram from features) can be done by minimizing a loss function involving a Wasserstein distance or Sinkhorn divergence between the predicted and target distributions.
- Domain Adaptation: OT can align distributions from different domains by finding a transport plan or map between them, allowing models trained on a source domain to perform better on a target domain.
- Structured Data Comparison: OT is applied to compare structured data like point clouds, graphs, or images by representing them as distributions and computing OT distances (e.g., Earth Mover's Distance for histograms, Gromov-Wasserstein for metric spaces).
- Optimization and Flows:
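- Wasserstein Gradient Flows: OT equips the space of measures with a geometry, which allows defining gradient flows of functionals $f(\alpha)$. The flow $\partial_t\alpha_t+\mathrm{div}(\alpha_tv_t)=0$ with velocity $v_t=-\nabla f'(\alpha_t)$, where $f'(\alpha)$ denotes the first variation of $f$, describes how $\alpha_t$ evolves to decrease $f$. Examples include the heat equation (the flow of the negative entropy $\int\rho\log\rho$) and interacting particle systems (the flow of an interaction energy).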
- Implicit Bias of SGD: OT geometry can be used to analyze the dynamics of optimization algorithms. For example, training a two-layer neural network can be viewed as a Wasserstein gradient flow on the distribution of neuron parameters, and these dynamics can reveal properties such as convergence to global optima under certain conditions.
- Transformer Dynamics: Attention mechanisms in transformers can be modeled as updates to token distributions, leading to PDEs on the space of token measures.
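On the real line, the $W_2$ barycenter has a simple closed form: its quantile function is the weighted average of the input quantile functions. The sketch below uses this 1D fact instead of the Sinkhorn-based solvers mentioned for barycenters. It assumes uniform empirical measures with the same number of points, so that sorted samples play the role of quantile functions; `barycenter_1d` is a hypothetical helper.

```python
import numpy as np

def barycenter_1d(samples, weights=None):
    """W2 barycenter of uniform empirical measures on the real line, all with the
    same number of points: a weighted average of the sorted samples (quantiles)."""
    samples = [np.sort(np.asarray(s, float)) for s in samples]
    k = len(samples)
    weights = np.full(k, 1.0 / k) if weights is None else np.asarray(weights, float)
    return sum(w * s for w, s in zip(weights, samples))

rng = np.random.default_rng(0)
a1 = rng.normal(-2.0, 0.5, size=1000)    # samples of N(-2, 0.5^2)
a2 = rng.normal(3.0, 1.5, size=1000)     # samples of N(3, 1.5^2)
bar = barycenter_1d([a1, a2])
print(bar.mean(), bar.std())             # close to 0.5 and 1.0
```

For these two Gaussians with equal weights the exact barycenter is $\mathcal N(0.5,1^2)$: means and standard deviations are averaged, whereas the linear average $\frac12(\alpha_1+\alpha_2)$ is a bimodal mixture.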
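Wasserstein gradient flows of interaction energies can be simulated with particles. For $f(\alpha)=\frac12\iint W(x-y)\,d\alpha(x)\,d\alpha(y)$ with a symmetric $W$ and $\alpha=\frac1n\sum_i\delta_{x_i}$, each particle moves with velocity $-\frac1n\sum_j\nabla W(x_i-x_j)$. This is a minimal explicit-Euler sketch; the quadratic kernel $W(z)=\|z\|^2/2$, the step size, and the helper name `interaction_flow` are illustrative assumptions.

```python
import numpy as np

def interaction_flow(X, grad_W, step=0.01, n_steps=500):
    """Explicit Euler for dx_i/dt = -(1/n) sum_j grad_W(x_i - x_j), the particle
    form of the Wasserstein gradient flow of an interaction energy."""
    X = X.copy()
    for _ in range(n_steps):
        diffs = X[:, None, :] - X[None, :, :]     # (n, n, d) pairwise x_i - x_j
        velocity = -grad_W(diffs).mean(axis=1)    # -(1/n) sum_j grad_W(x_i - x_j)
        X += step * velocity
    return X

grad_W = lambda z: z                 # W(z) = |z|^2 / 2 (quadratic attraction)
rng = np.random.default_rng(0)
X0 = rng.normal(size=(100, 2))
X1 = interaction_flow(X0, grad_W)
print(np.abs(X1.mean(0) - X0.mean(0)).max())   # centre of mass is conserved
print(X1.std(0) / X0.std(0))                   # spread shrinks by (1 - step)**n_steps
```

With this quadratic kernel the velocity is $-(x_i-\bar x)$, so the centre of mass $\bar x$ is conserved and every particle contracts toward it; each Euler step multiplies the deviations by exactly $1-\text{step}$.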
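The flow-matching recipe can be illustrated end to end in 1D. This sketch makes several simplifying assumptions that are not from the notes: $\alpha_0=\mathcal N(0,1)$, $\alpha_1=\mathcal N(\mu,\sigma^2)$, the independent coupling $\pi=\alpha_0\otimes\alpha_1$, the linear interpolation $P_t(x,y)=(1-t)x+ty$ (so $\partial_tP_t=y-x$), and, in place of a neural network, a separate linear least-squares fit $v_t(z)\approx a_tz+c_t$ at each time step. With Gaussian endpoints the conditional expectation is exactly affine in $z$, so this regression targets the true field. The function names are hypothetical.

```python
import numpy as np

def fit_linear_velocity(t, n, mu, sigma, rng):
    """Least-squares fit of v_t(z) ~ a*z + c on samples of the interpolation.
    Regresses the target d/dt P_t(x, y) = y - x on z = P_t(x, y) = (1 - t) x + t y,
    with (x, y) drawn from the independent coupling N(0, 1) x N(mu, sigma^2)."""
    x = rng.normal(0.0, 1.0, size=n)
    y = rng.normal(mu, sigma, size=n)
    z = (1.0 - t) * x + t * y
    a, c = np.polyfit(z, y - x, deg=1)          # slope, intercept
    return a, c

def flow_matching_sample(n_particles=2000, n_steps=100, mu=3.0, sigma=0.5, seed=0):
    rng = np.random.default_rng(seed)
    X = rng.normal(size=n_particles)            # X_0 ~ alpha_0 = N(0, 1)
    h = 1.0 / n_steps
    for k in range(n_steps):
        a, c = fit_linear_velocity(k * h, 5000, mu, sigma, rng)
        X = X + h * (a * X + c)                 # explicit Euler step of dX/dt = v_t(X)
    return X

X1 = flow_matching_sample()
print(X1.mean(), X1.std())                      # approximately mu = 3.0 and sigma = 0.5
```

Integrating the fitted field moves samples of $\alpha_0$ to approximately $\mathcal N(\mu,\sigma^2)$; the residual error comes from the time discretization and the finite samples.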
In summary, the notes emphasize that Optimal Transport offers a powerful framework for understanding and manipulating probability distributions in a geometrically intuitive way. While classic formulations were computationally intensive, modern techniques like entropic regularization (Sinkhorn) and stochastic semi-discrete methods make OT practical for large-scale machine learning problems. Its applications range from foundational tasks like comparing distributions and computing averages to advanced generative modeling techniques and analyzing the dynamics of complex deep learning architectures.