
Optimal Transport for Machine Learners

Published 10 May 2025 in stat.ML, cs.AI, and math.OC | arXiv:2505.06589v1

Abstract: Optimal Transport is a foundational mathematical theory that connects optimization, partial differential equations, and probability. It offers a powerful framework for comparing probability distributions and has recently become an important tool in machine learning, especially for designing and evaluating generative models. These course notes cover the fundamental mathematical aspects of OT, including the Monge and Kantorovich formulations, Brenier's theorem, the dual and dynamic formulations, the Bures metric on Gaussian distributions, and gradient flows. It also introduces numerical methods such as linear programming, semi-discrete solvers, and entropic regularization. Applications in machine learning include topics like training neural networks via gradient flows, token dynamics in transformers, and the structure of GANs and diffusion models. These notes focus primarily on mathematical content rather than deep learning techniques.

Summary

  • The paper introduces a unified framework combining Monge and Kantorovich formulations to efficiently compute optimal transport plans between probability distributions.
  • It presents scalable computational methods, including entropic regularization via the Sinkhorn algorithm, to handle large-scale machine learning problems.
  • The study demonstrates how optimal transport enhances generative modeling and domain adaptation by providing meaningful geometric metrics for comparing complex data.

Optimal Transport (OT) is a mathematical framework for comparing probability distributions or measures by finding the most efficient way to move mass from one distribution to another. Originating with Monge's problem in the 18th century and formalized by Kantorovich in the 20th century as a linear program, OT has recently gained significant traction in machine learning because it provides a geometrically meaningful distance between distributions. It is particularly useful for structured data such as images, text, or point clouds, where traditional measures like the Kullback-Leibler divergence fall short because they are infinite for distributions with non-overlapping supports. These notes (2505.06589) cover the fundamental mathematical concepts, computational methods, and applications of OT relevant to machine learning practitioners.

Mathematical Foundations

The core idea of OT is finding a "transport plan" that minimizes the total cost of moving mass from a source distribution $\alpha$ to a target distribution $\beta$, given a cost function $c(x,y)$ between points $x$ in the source space $\mathcal{X}$ and $y$ in the target space $\mathcal{Y}$.

  1. Monge Problem: This is the original formulation, seeking a map $T: \mathcal{X} \to \mathcal{Y}$ such that $T_\sharp \alpha = \beta$ (meaning $T$ pushes forward $\alpha$ to $\beta$) and minimizes $\int_{\mathcal{X}} c(x, T(x))\, d\alpha(x)$.
    • Discrete Case: For two sets of points $\{x_i\}_{i=1}^n$ and $\{y_j\}_{j=1}^m$, the simplest case is matching $n=m$ points with uniform mass $1/n$. This is the optimal assignment problem; naively, it requires searching over $n!$ permutations. For 1D points and convex costs $c(x,y)=h(|x-y|)$, the optimal map is obtained by sorting the points, giving an $O(n \log n)$ solution. A practical application is grayscale histogram equalization in image processing, where the CDFs of the source and target image histograms are matched. However, this 1D sorting trick does not extend to higher dimensions.
    • Continuous Case: The push-forward constraint $T_\sharp \alpha = \beta$ means $\int_{\mathcal{Y}} h(y)\, d\beta(y) = \int_{\mathcal{X}} h(T(x))\, d\alpha(x)$ for any test function $h$. A major limitation is that an optimal map $T$ may not exist (e.g., transporting a single Dirac mass to multiple locations).
    • Brenier's Theorem: For the $L^2$ cost $c(x,y) = \|x-y\|^2$ in $\mathbb{R}^d$, if $\alpha$ has a density, a unique optimal map $T$ exists and is the gradient of a convex function, $T = \nabla \phi$. This provides valuable structure for the optimal map, though the map can still be computationally challenging to find. It also leads to the Monge-Ampère equation $\det(\partial^2 \phi(x))\, \rho_\beta(\nabla \phi(x)) = \rho_\alpha(x)$ relating the densities $\rho_\alpha, \rho_\beta$ and the potential $\phi$.
  2. Kantorovich Relaxation: To address the limitations of Monge's formulation, Kantorovich introduced a relaxed problem using "couplings" or "transport plans." A coupling $\pi$ is a joint probability measure on $\mathcal{X} \times \mathcal{Y}$ with marginals $\pi_1 = \alpha$ and $\pi_2 = \beta$. The problem is to minimize $\int_{\mathcal{X} \times \mathcal{Y}} c(x,y)\, d\pi(x,y)$ over the set of admissible couplings $\Pi(\alpha, \beta)$.
    • Discrete Case: For discrete measures $\alpha = \sum_i a_i \delta_{x_i}$ and $\beta = \sum_j b_j \delta_{y_j}$, a coupling is a matrix $\mathbf{P} \in \mathbb{R}_+^{n \times m}$ where $\mathbf{P}_{ij}$ is the mass moved from $x_i$ to $y_j$. The marginal constraints are $\mathbf{P} \mathbf{1}_m = \mathbf{a}$ and $\mathbf{P}^\top \mathbf{1}_n = \mathbf{b}$, and the objective is $\sum_{i,j} \mathbf{C}_{ij} \mathbf{P}_{ij}$. This is a linear program, solvable by efficient algorithms such as the Hungarian algorithm ($O(n^3)$ for assignment), the auction algorithm, and the network simplex ($O(n^3 \log n)$, or faster variants). The Birkhoff-von Neumann theorem shows that for the assignment problem ($n=m$, uniform weights), the extreme points of the feasible set (bistochastic matrices) are permutation matrices, guaranteeing that an optimal coupling exists which corresponds to a Monge map (a permutation).
    • Continuous Case: The Kantorovich problem is an infinite-dimensional linear program. It always has a solution, and the set of couplings $\Pi(\alpha, \beta)$ is non-empty (it contains $\alpha \otimes \beta$). The problem can be interpreted probabilistically as minimizing $\mathbb{E}[c(X,Y)]$ over random variables $(X,Y)$ with marginal distributions $\alpha$ and $\beta$.
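To make the discrete picture above concrete, the 1D sorting solution to the assignment problem can be sketched in a few lines of numpy (the point sets here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=6)        # source points, each with mass 1/6
y = rng.normal(size=6) + 2.0  # target points, each with mass 1/6

# For a convex cost c(x, y) = h(|x - y|) in 1D, the optimal Monge map
# matches the i-th smallest source point to the i-th smallest target point.
sigma = np.argsort(x)
tau = np.argsort(y)
T = np.empty(6, dtype=int)
T[sigma] = tau                # x[i] is sent to y[T[i]]

# Squared 2-Wasserstein distance between the two empirical measures
w2_sq = np.mean((np.sort(x) - np.sort(y)) ** 2)
```

By construction the sorted matching is never worse than the identity matching, which a quick comparison of `w2_sq` with `np.mean((x - y) ** 2)` confirms.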

Metric Properties and Duality

  1. Wasserstein Distance: If the cost is $c(x,y) = d(x,y)^p$ for a ground distance $d$ on $\mathcal{X}=\mathcal{Y}$, the minimum cost $W_p^p(\alpha, \beta) = \min_{\pi \in \Pi(\alpha, \beta)} \int d(x,y)^p\, d\pi(x,y)$ defines the $p$-Wasserstein distance. $W_p$ is a true metric: it is symmetric and satisfies the triangle inequality (proven using a "gluing lemma" for couplings).
  2. Weak Convergence: A key property of $W_p$ (for $p \ge 1$) is that it metrizes the weak convergence of probability measures (on compact spaces; on unbounded domains one also needs convergence of $p$-th moments): $\alpha_k \to \alpha$ weakly if and only if $W_p(\alpha_k, \alpha) \to 0$. This is crucial in machine learning, as it allows comparing distributions that may not have overlapping support (e.g., empirical measures or Dirac masses), unlike divergences such as KL. For example, $W_p(\delta_{x_n}, \delta_x) = d(x_n, x)$, so convergence of Dirac masses is equivalent to convergence of their locations. This contrasts with the Total Variation (TV) distance, where $\mathrm{TV}(\delta_{x_n}, \delta_x) = 2$ for $x_n \neq x$.
  3. Dual Problem: The Kantorovich problem has a dual formulation. For continuous measures and cost $c$, it is $\max_{f,g} \int f\, d\alpha + \int g\, d\beta$ subject to $f(x) + g(y) \le c(x,y)$ for all $x,y$. The functions $(f,g)$ are called Kantorovich potentials.
    • $c$-transforms: The constraint $f(x) + g(y) \le c(x,y)$ is equivalent to $f(x) \le c(x,y) - g(y)$ for all $x,y$, i.e., $f(x) \le \inf_y (c(x,y) - g(y)) = g^{\bar c}(x)$, where $\bar c(y,x)=c(x,y)$. Similarly, $g(y) \le f^c(y) = \inf_x (c(x,y) - f(x))$. Optimal potentials satisfy $f = g^{\bar c}$ and $g = f^c$ on the supports of $\alpha$ and $\beta$. For $c(x,y) = \|x-y\|^2/2$, the $c$-transform relates to the Fenchel-Legendre transform, and optimal potentials are related to convex functions, connecting back to Brenier's theorem.
    • $W_1$ Duality: For $c(x,y)=d(x,y)$, the dual problem simplifies significantly. Under the constraint $f(x)+g(y) \le d(x,y)$, one can take $f$ 1-Lipschitz and $g = -f$. The dual problem becomes $W_1(\alpha, \beta) = \sup_{f: \mathrm{Lip}(f) \le 1} \int f\, d(\alpha-\beta)$, known as the Kantorovich-Rubinstein theorem. This form is very useful for theoretical analysis and for connections to other metrics.
    • Integral Probability Metrics (IPMs): The $W_1$ dual form is an instance of an IPM, $\sup_{f \in \mathcal{F}} \int f\, d(\alpha-\beta)$ for some function class $\mathcal{F}$. Other IPMs include the Maximum Mean Discrepancy (MMD), $\sup_{f \in \mathcal{B}} \int f\, d(\alpha-\beta)$, where $\mathcal{B}$ is the unit ball of an RKHS defined by a kernel. MMD is often easier to compute (especially for empirical measures) but may not capture geometric structure as well as Wasserstein distances.
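As a concrete contrast with $W_1$, here is a minimal numpy sketch of the (biased) empirical MMD estimator with a Gaussian kernel; the bandwidth and the sample sets are illustrative choices:

```python
import numpy as np

def gaussian_kernel(X, Y, sigma=1.0):
    # Pairwise Gaussian kernel matrix k(x, y) = exp(-||x - y||^2 / (2 sigma^2))
    d2 = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * sigma**2))

def mmd_squared(X, Y, sigma=1.0):
    # Biased (V-statistic) estimate of MMD^2 between two empirical measures
    Kxx = gaussian_kernel(X, X, sigma)
    Kyy = gaussian_kernel(Y, Y, sigma)
    Kxy = gaussian_kernel(X, Y, sigma)
    return Kxx.mean() + Kyy.mean() - 2 * Kxy.mean()

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
Y = rng.normal(size=(200, 2)) + np.array([3.0, 0.0])
```

Note the closed-form estimate requires only kernel evaluations, which is why MMD is cheap compared with solving an optimization problem over couplings.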

Computational Methods and Regularization

Solving the Kantorovich problem directly with general-purpose LP solvers can be slow for large-scale ML problems ($O(N^3)$ or worse for $N$ points). Several approaches have been developed:
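Before turning to these scalable methods, a small instance of the exact discrete Kantorovich LP can be solved with an off-the-shelf solver; this sketch uses scipy.optimize.linprog, with illustrative point clouds and uniform weights:

```python
import numpy as np
from scipy.optimize import linprog

# Minimize <C, P> subject to P 1 = a, P^T 1 = b, P >= 0.
rng = np.random.default_rng(0)
n, m = 4, 5
x = rng.normal(size=(n, 2))
y = rng.normal(size=(m, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared-distance cost
a = np.full(n, 1.0 / n)
b = np.full(m, 1.0 / m)

# P is flattened row-major, so P[i, j] sits at index i*m + j.
A_rows = np.kron(np.eye(n), np.ones((1, m)))  # row sums equal a
A_cols = np.kron(np.ones((1, n)), np.eye(m))  # column sums equal b
A_eq = np.vstack([A_rows, A_cols])
b_eq = np.concatenate([a, b])

res = linprog(C.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
P = res.x.reshape(n, m)   # optimal coupling
ot_cost = res.fun         # exact OT cost
```

This works fine at this scale, but the constraint matrix has $(n+m) \times nm$ entries, which is exactly why the specialized solvers below matter.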

  1. Semi-Discrete OT: If one measure (e.g., $\beta = \sum_j b_j \delta_{y_j}$) is discrete and the other ($\alpha$) has a density, the problem is semi-discrete. The dual problem becomes a finite-dimensional optimization over the dual variables $\mathbf{g}=(g_j)_j$ associated with the discrete points $y_j$: $\max_{\mathbf{g}} \int \inf_j \big(c(x, y_j) - g_j\big)\, d\alpha(x) + \sum_j g_j b_j$. The gradient of this objective involves integrals over Laguerre cells (generalized Voronoi regions) defined by the points $y_j$ and weights $g_j$. This structure allows for specialized algorithms, including stochastic optimization (SGD) obtained by sampling from $\alpha$, which is practical for large datasets. Semi-discrete OT is closely related to optimal quantization and $k$-means clustering ($L^2$ cost).
  2. Entropic Regularization (Sinkhorn): Adding an entropy term $\epsilon\, \mathrm{KL}(\pi \mid \alpha \otimes \beta)$ to the Kantorovich objective yields the strictly convex problem $\min_{\pi \in \Pi(\alpha,\beta)} \int c(x,y)\, d\pi(x,y) + \epsilon\, \mathrm{KL}(\pi \mid \alpha \otimes \beta)$. Its unique solution $\pi_\epsilon$ has a special form: $\pi_\epsilon(x,y) = u(x) K(x,y) v(y)\, d\alpha(x) d\beta(y)$, where $K(x,y) = e^{-c(x,y)/\epsilon}$. This form leads to Sinkhorn's algorithm (also known as Iterative Proportional Fitting), which iteratively rescales the potentials $u, v$ to satisfy the marginal constraints.
    • Sinkhorn Algorithm: For discrete measures, this amounts to iteratively scaling the rows and columns of $\mathbf{K} = e^{-\mathbf{C}/\epsilon}$:
      import numpy as np

      # Inputs: cost matrix C (n x m), marginal weights a (n,) and b (m,),
      # regularization epsilon > 0, and an iteration count num_iterations.
      K = np.exp(-C / epsilon)                 # Gibbs kernel
      u = np.ones(n) / n                       # initial scaling vectors
      v = np.ones(m) / m
      for _ in range(num_iterations):
          v = b / (K.T @ u)                    # enforce column marginals b
          u = a / (K @ v)                      # enforce row marginals a
      P_epsilon = u[:, None] * K * v[None, :]  # optimal regularized coupling
      sinkhorn_cost = np.sum(C * P_epsilon)    # regularized transport cost
    • Sinkhorn is much faster than general LP solvers, with complexity $O(nm)$ per iteration, and is highly parallelizable (GPU-friendly). The number of iterations needed depends on $\epsilon$, and the algorithm becomes slow for small $\epsilon$.
    • As $\epsilon \to 0$, $\pi_\epsilon$ converges to an optimal coupling of the original (unregularized) Kantorovich problem (specifically, the one with maximum entropy).
    • As $\epsilon \to \infty$, $\pi_\epsilon \to \alpha \otimes \beta$ (the independent coupling), and the regularized cost converges to $\int c\, d\alpha\, d\beta$. This highlights the bias of the regularized cost for large $\epsilon$.
  3. Sinkhorn Divergences: To address the entropic bias, debiased versions such as $\bar{W}_\epsilon(\alpha, \beta) = \mathrm{MK}^\epsilon(\alpha, \beta) - \frac{1}{2}\mathrm{MK}^\epsilon(\alpha, \alpha) - \frac{1}{2}\mathrm{MK}^\epsilon(\beta, \beta)$ (where $\mathrm{MK}^\epsilon$ is the entropic OT cost) have been proposed. These "Sinkhorn divergences" are non-negative and vanish iff $\alpha=\beta$. They converge to the unregularized OT cost as $\epsilon \to 0$ and to a kernel distance as $\epsilon \to \infty$, providing a family of discrepancies that interpolate between OT and MMD and often offering a good trade-off between computational cost and geometric fidelity.
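A minimal sketch of such a debiased divergence for uniform empirical measures; here each $\mathrm{MK}^\epsilon$ term is taken to be the transport cost $\langle \mathbf{C}, \mathbf{P}_\epsilon \rangle$ of the entropic plan (one common "sharp" variant), and the function names and parameter values are illustrative:

```python
import numpy as np

def entropic_ot_cost(C, a, b, epsilon, iters=300):
    # Plain Sinkhorn; returns the transport cost <C, P_eps> of the entropic plan
    K = np.exp(-C / epsilon)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(iters):
        v = b / (K.T @ u)
        u = a / (K @ v)
    P = u[:, None] * K * v[None, :]
    return np.sum(C * P)

def sinkhorn_divergence(x, y, epsilon=0.5):
    # Debiased combination MK(x,y) - MK(x,x)/2 - MK(y,y)/2, uniform weights
    def sqdist(p, q):
        return ((p[:, None, :] - q[None, :, :]) ** 2).sum(-1)
    a = np.full(len(x), 1.0 / len(x))
    b = np.full(len(y), 1.0 / len(y))
    return (entropic_ot_cost(sqdist(x, y), a, b, epsilon)
            - 0.5 * entropic_ot_cost(sqdist(x, x), a, a, epsilon)
            - 0.5 * entropic_ot_cost(sqdist(y, y), b, b, epsilon))

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))
Y = rng.normal(size=(50, 2)) + np.array([1.0, 0.0])
```

The self-terms cancel the entropic bias, so the divergence of a distribution with itself is zero while the raw entropic cost is not.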

Optimal Transport in Machine Learning Applications

OT provides flexible tools for comparing and transforming probability distributions, finding applications across ML:

  1. Generative Models:
    • GANs: The dual formulation of $W_1$ (and other IPMs/divergences), $\sup_{f \in \mathcal{F}} \int f\, d(\alpha-\beta)$, can be used as a training objective. If $\alpha$ is the model distribution and $\beta$ is the data distribution, the supremum over $f$ can be approximated by a discriminator network $f_\theta$, leading to a min-max game: $\min_{\text{generator}} \max_{\theta} \int f_\theta\, d\alpha - \int f_\theta\, d\beta$. Wasserstein GANs use $W_1$ by constraining the discriminator's Lipschitz constant.
    • Flow Matching: A modern approach to generative modeling defines a time-dependent vector field $v_t$ that transports a simple distribution $\alpha_0$ (e.g., a Gaussian) to a target $\alpha_1$. This field is derived from an interpolation $\alpha_t = (P_t)_\sharp \pi$ for some coupling $\pi$ and map $P_t$. The field $v_t$ is shown to be a conditional expectation, $v_t(z) = \mathbb{E}_\pi([\partial_t P_t](u) \mid z = P_t(u))$, which minimizes a simple least-squares objective. This objective can be minimized by training a neural network $v_\theta(t, z)$ on samples $(P_t(u), [\partial_t P_t](u))$ drawn from the interpolation process. Sampling from $\alpha_1$ is then achieved by integrating the ODE $\dot{X}_t = v_\theta(t, X_t)$ starting from $X_0 \sim \alpha_0$.
  2. Comparing Samples/Datasets: $W_p$ and Sinkhorn divergences are used as distances between empirical distributions. For datasets $\{x_i\}_{i=1}^n$ and $\{y_j\}_{j=1}^m$, the distance is computed between $\frac{1}{n}\sum_i \delta_{x_i}$ and $\frac{1}{m}\sum_j \delta_{y_j}$. This is useful for measuring differences between datasets, two-sample testing, and evaluating generative models.
  3. Data Structures and Learning:
    • Wasserstein Barycenters: The mean of a set of distributions $\{\alpha_k\}$ in Wasserstein space is $\arg\min_\alpha \sum_k W_2^2(\alpha, \alpha_k)$. This is a convex problem and can be solved using Sinkhorn-based methods. Barycenters are used for averaging, clustering, and dimensionality reduction (e.g., Wasserstein PCA).
    • Distribution Regression: Learning models that map inputs to probability distributions (e.g., predicting a histogram from features) can be done by minimizing a loss function involving a Wasserstein distance or Sinkhorn divergence between the predicted and target distributions.
    • Domain Adaptation: OT can align distributions from different domains by finding a transport plan or map between them, allowing models trained on a source domain to perform better on a target domain.
    • Structured Data Comparison: OT is applied to compare structured data like point clouds, graphs, or images by representing them as distributions and computing OT distances (e.g., Earth Mover's Distance for histograms, Gromov-Wasserstein for metric spaces).
  4. Optimization and Flows:
    • Wasserstein Gradient Flows: OT endows the space of measures with a geometry, allowing the definition of gradient flows for functionals $f(\alpha)$. The flow $\partial_t \alpha_t + \operatorname{div}(\alpha_t \nabla_W f(\alpha_t)) = 0$ describes how $\alpha_t$ evolves to minimize $f$. Examples include the heat equation (flow of entropy) and particle systems (flow of an interaction energy).
    • Implicit Bias of SGD: OT geometry can analyze the dynamics of optimization algorithms. For example, training two-layer neural networks can be viewed as a Wasserstein flow on the distribution of neuron parameters. The flow dynamics can reveal properties like convergence to global optima under certain conditions.
    • Transformer Dynamics: Attention mechanisms in transformers can be modeled as updates to token distributions, leading to PDEs on the space of token measures.
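For the dataset-comparison use case above, the exact $W_2$ distance between two uniform empirical measures with equally many points reduces to an assignment problem; a sketch using scipy's Hungarian-type solver (the point clouds are illustrative):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def w2_empirical(x, y):
    # Exact W2 between two uniform empirical measures with equally many
    # points: solve the optimal assignment on the squared-distance cost.
    C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    i, j = linear_sum_assignment(C)
    return np.sqrt(C[i, j].mean())

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
Y = rng.normal(size=(100, 2)) + np.array([2.0, 0.0])
```

With a mean shift of 2 between the clouds, `w2_empirical(X, Y)` comes out close to 2, as expected for translated Gaussians.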

In summary, the notes emphasize that Optimal Transport offers a powerful framework for understanding and manipulating probability distributions in a geometrically intuitive way. While classic formulations were computationally intensive, modern techniques like entropic regularization (Sinkhorn) and stochastic semi-discrete methods make OT practical for large-scale machine learning problems. Its applications range from foundational tasks like comparing distributions and computing averages to advanced generative modeling techniques and analyzing the dynamics of complex deep learning architectures.
