
FlashSinkhorn: Efficient EOT Computation

Updated 10 February 2026
  • FlashSinkhorn is a family of algorithms that efficiently solves entropic-regularized optimal transport problems using optimized GPU streaming and NFFT acceleration.
  • It fuses kernel operations and streams matrix tiles to drastically reduce memory load and HBM traffic, overcoming computational bottlenecks.
  • Empirical results show significant speedups over traditional methods, enabling scalable transport computations on modern accelerator architectures.

FlashSinkhorn denotes a family of algorithms designed for efficient large-scale computation of entropic-regularized optimal transport (EOT) via Sinkhorn-type iterations, with highly optimized implementations exploiting either hardware-aware GPU streaming and fusion techniques or nonequispaced fast Fourier transform (NFFT) acceleration. Recent advances under the name "FlashSinkhorn" address the bandwidth, memory, and computational bottlenecks that have limited the scaling of EOT solvers, achieving significant gains both on modern GPU architectures (especially A100) and in dense, CPU-bound settings (Ye et al., 3 Feb 2026, Lakshmanan et al., 2022).

1. Entropic Optimal Transport and the Sinkhorn Algorithm

Let $X = \{x_i\}_{i=1}^n \subset \mathbb{R}^d$ and $Y = \{y_j\}_{j=1}^m \subset \mathbb{R}^d$ with weights $a \in \Delta^n$, $b \in \Delta^m$, and cost matrix $C_{ij} = \|x_i - y_j\|_2^2$. The entropic-regularized OT problem is

$$\mathrm{OT}_\varepsilon(a, b; C) = \min_{P \in \Pi(a, b)} \langle C, P \rangle + \varepsilon\, \mathrm{KL}(P \,\|\, a \otimes b),$$

where $\Pi(a, b) = \{P \ge 0 : P \mathbf{1}_m = a,\ P^T \mathbf{1}_n = b\}$ and $\mathrm{KL}(P \,\|\, a \otimes b)$ is the relative entropy.

Scaling approaches such as Sinkhorn's algorithm alternate between updates to the dual potentials $f \in \mathbb{R}^n$ and $g \in \mathbb{R}^m$; the stabilized log-domain variant is

$$\begin{align*} f_i &\gets -\varepsilon\,\mathrm{LogSumExp}_j\left[\frac{g_j - C_{ij}}{\varepsilon} + \log b_j\right],\\ g_j &\gets -\varepsilon\,\mathrm{LogSumExp}_i\left[\frac{f_i - C_{ij}}{\varepsilon} + \log a_i\right]. \end{align*}$$

The dense $n \times m$ interactions have made the memory and compute costs of these methods impractical for large $n, m$.
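As a concrete reference, the log-domain updates above can be written as a minimal pure-Python sketch (illustrative only; the function names and the toy problem are not from the cited papers, and real implementations vectorize over $i$ and $j$):

```python
import math

def logsumexp(vals):
    """Numerically stable log(sum(exp(vals)))."""
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))

def sinkhorn_log(C, a, b, eps, n_iter=200):
    """Log-domain Sinkhorn: returns dual potentials f, g for OT_eps(a, b; C)."""
    n, m = len(a), len(b)
    f, g = [0.0] * n, [0.0] * m
    for _ in range(n_iter):
        for i in range(n):
            f[i] = -eps * logsumexp(
                [(g[j] - C[i][j]) / eps + math.log(b[j]) for j in range(m)])
        for j in range(m):
            g[j] = -eps * logsumexp(
                [(f[i] - C[i][j]) / eps + math.log(a[i]) for i in range(n)])
    return f, g

def transport_plan(C, a, b, f, g, eps):
    """Primal plan P_ij = a_i b_j exp((f_i + g_j - C_ij) / eps)."""
    return [[a[i] * b[j] * math.exp((f[i] + g[j] - C[i][j]) / eps)
             for j in range(len(b))] for i in range(len(a))]
```

At convergence the plan's row and column marginals match $a$ and $b$; the quadratic cost of the inner list comprehensions is exactly the dense $n \times m$ interaction that the methods below avoid.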

2. FlashSinkhorn: IO-Aware GPU Algorithms

The FlashSinkhorn method in (Ye et al., 3 Feb 2026) addresses performance limits in GPU-based EOT solvers by recasting log-domain Sinkhorn updates in a form structurally identical to a row-wise LogSumExp over biased dot-product scores:

$$\hat f \gets -\,\mathrm{LSE}_{\rm row}(S_X(\hat g)), \qquad S_X(\hat g) = \frac{Q K^T + \mathbf{1}_n (\hat g + \delta)^T}{\varepsilon},$$

with $Q = \sqrt{2}\,X$, $K = \sqrt{2}\,Y$, $\delta = \varepsilon \log b$, $\hat f = f - \alpha$, and $\alpha_i = \|x_i\|^2$.

This structure enables a direct mapping to Transformer-style attention: LogSumExp reductions of pairwise dot-product matrices with additive biases.
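The mapping rests on the squared-distance expansion $C_{ij} = \|x_i\|^2 + \|y_j\|^2 - 2\langle x_i, y_j\rangle$, so the $-C_{ij}/\varepsilon$ term inside the LogSumExp becomes a dot-product score plus row and column biases. A quick numeric check of this identity on hypothetical toy data:

```python
import random

random.seed(0)
d = 4
x = [random.gauss(0, 1) for _ in range(d)]  # hypothetical point in X
y = [random.gauss(0, 1) for _ in range(d)]  # hypothetical point in Y

def sq(v):
    return sum(vk * vk for vk in v)

dot = sum(xk * yk for xk, yk in zip(x, y))
C = sum((xk - yk) ** 2 for xk, yk in zip(x, y))  # squared Euclidean cost

# One entry of Q K^T with Q = sqrt(2) X, K = sqrt(2) Y:
qk = 2.0 * dot
# Identity behind the attention-style mapping:
# C_ij = |x_i|^2 + |y_j|^2 - (Q K^T)_ij
assert abs(C - (sq(x) + sq(y) - qk)) < 1e-12
```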

By fusing all computation stages (dot products, biasing, LogSumExp) into a single custom kernel, and streaming matrix tiles through on-chip SRAM, FlashSinkhorn achieves the following workflow:

  • Partition QQ and KK into blocks fitting in SRAM.
  • Stream each tile, compute the required scores and accumulate row-wise LogSumExp statistics.
  • On completion, update potentials and transfer only summarizing statistics to high-bandwidth memory (HBM).
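The tile loop above hinges on an online LogSumExp: a running (max, sum) pair is rescaled as each tile arrives, so a full score row never has to be materialized. A schematic pure-Python sketch (the tile size and data are illustrative, not the Triton kernel):

```python
import math

def streaming_lse(scores, tile=4):
    """Online LogSumExp over one score row, processed tile by tile.

    Maintains a running maximum m and a rescaled sum s such that
    log(sum(exp(scores_seen))) == m + log(s) after every tile.
    """
    m, s = float("-inf"), 0.0
    for start in range(0, len(scores), tile):
        block = scores[start:start + tile]
        m_new = max(m, max(block))
        # Rescale the accumulated sum to the new maximum, then add the tile.
        s = s * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in block)
        m = m_new
    return m + math.log(s)
```

The result agrees with a one-shot LogSumExp up to floating-point roundoff, which is what lets the kernel keep only per-row statistics in SRAM.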

For block sizes $B_N$, $B_M$ and on-chip SRAM of size $M$, the entire per-iteration HBM traffic is reduced to

$$\Theta\!\left(nd + md + \frac{n m d^2}{M}\right),$$

which simplifies to $\Theta(nd + md)$ whenever $M \gtrsim d \cdot \min(n, m)$, i.e. when SRAM can hold the smaller point cloud's feature block (Ye et al., 3 Feb 2026).
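As a hedged back-of-envelope comparison (assuming 4-byte elements and counting only the dominant reads and writes; the function name and constants are illustrative, not from the paper):

```python
def hbm_traffic_bytes(n, m, d, bytes_per_el=4):
    """Rough per-iteration HBM traffic: streaming vs. dense tensorized.

    Streaming:  Theta(nd + md) -- read both point clouds, write potentials.
    Tensorized: Theta(nm)      -- materialize the n x m score matrix.
    Constant factors and tiling effects are deliberately ignored.
    """
    streaming = (n * d + m * d + n + m) * bytes_per_el
    tensorized = n * m * bytes_per_el
    return streaming, tensorized

s, t = hbm_traffic_bytes(10_000, 10_000, 64)
# At n = m = 10k, d = 64, the dense score matrix alone costs roughly
# 75x the streaming traffic per iteration under these assumptions.
```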

3. Memory Complexity, Bandwidth and Profiling

A comparative summary of resource usage is as follows:

| Implementation | Intermediates | HBM Traffic (per iter) | Total Memory |
|---|---|---|---|
| Tensorized | $O(nm)$ | $\Theta(nm)$ | $O(nm)$ |
| KeOps-style online | none | $\Theta(nm)$ | lower, but low SM utilization |
| FlashSinkhorn | $O((n+m)d)$ | $0.08$ GB (10 iters, $n=m=10$k, $d=64$) | $O((n+m)d)$ |

Nsight profiling of FlashSinkhorn on A100 GPUs shows only 3% memory stalls and 74% SM utilization, compared with 79% memory stalls for the tensorized algorithm.

4. NFFT-Accelerated FlashSinkhorn

An alternative FlashSinkhorn algorithm, aimed primarily at CPU or moderate-scale problems, accelerates the Sinkhorn iteration using nonequispaced fast Fourier transforms (NFFTs) (Lakshmanan et al., 2022). The cost kernel $k_{ij} = \exp(-\lambda \|x_i - y_j\|^2)$ is approximated by a $d$-dimensional truncated Fourier series whose coefficients are computed efficiently by FFT and smooth windowing. Each Sinkhorn iteration then reduces to two NFFT-based convolution-type products, giving $O(n \log n)$ runtime per iteration instead of the $O(n^2)$ of classical dense-matrix methods.

For typical Gaussian kernels in moderate dimension, window width $p \sim 6$, oversampling $\sigma \sim 2$, and bandwidth $N = 128$ per dimension suffice to reach machine precision. All kernel interactions are performed implicitly; only scaling vectors and NFFT plans are stored.
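The factorization behind this can be illustrated in one dimension with plain Python, using direct exponential sums in place of the actual NFFT and illustrative parameters $h$, $N$ (not the paper's): once the periodized Gaussian is expanded in a truncated Fourier series, each frequency is visited once per point, so the double sum never forms the $n \times m$ kernel matrix.

```python
import cmath
import math

def gauss_kernel_fast_sum(xs, ys, v, lam, h=8.0, N=64):
    """Approximate s_i = sum_j exp(-lam * (x_i - y_j)^2) * v_j via a
    truncated Fourier series of the periodized 1-D Gaussian kernel.

    Uses the closed-form Gaussian Fourier coefficients
        c_k = sqrt(pi/lam)/h * exp(-(pi*k/h)^2 / lam),
    valid when the kernel decays well inside one period h.
    The double sum factorizes: cost O((n + m) * N) instead of O(n * m).
    """
    ks = range(-N, N + 1)
    c = {k: math.sqrt(math.pi / lam) / h * math.exp(-(math.pi * k / h) ** 2 / lam)
         for k in ks}
    # Inner sums over the source points (one per frequency).
    g = {k: sum(vj * cmath.exp(-2j * math.pi * k * yj / h)
                for yj, vj in zip(ys, v)) for k in ks}
    # Outer sums over the target points.
    return [sum(c[k] * g[k] * cmath.exp(2j * math.pi * k * xi / h)
                for k in ks).real for xi in xs]
```

The NFFT replaces the inner and outer exponential sums here with fast transforms over nonequispaced nodes, which is where the $O(n \log n)$ per-iteration cost comes from.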

5. Streaming Kernels for Transport and Second-Order Methods

FlashSinkhorn (Ye et al., 3 Feb 2026) implements not only potential updates but also transport-application and Hessian–vector product (HVP) operations, using the same streaming/fused kernel paradigm:

  • application of the transport map $PV$ by streaming softmax-weighted dot products,
  • adjoint transport $P^T U$ and Hadamard-weighted variants $(P \odot W)V$,
  • streaming HVPs via a decomposition involving Schur complements and on-the-fly conjugate-gradient solves for the dual Schur system.

All operations are unified into a minimal kernel set inheriting the bandwidth and memory advantages of the main update phase.
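For instance, applying the transport map to a vector never requires forming $P$: each tile of source points contributes its softmax-weighted terms on the fly. A minimal pure-Python sketch of one row of $PV$ (the function name, tiling, and toy data are illustrative, not the fused kernel):

```python
import math

def apply_transport_row(i, x, ys, a, b, f, g, eps, v, tile=4):
    """Compute (P v)_i = sum_j P_ij v_j without materializing row i of P,
    streaming source points tile by tile, where
    P_ij = a_i b_j exp((f_i + g_j - C_ij) / eps).
    """
    acc = 0.0
    for start in range(0, len(ys), tile):
        for j in range(start, min(start + tile, len(ys))):
            Cij = sum((xk - yk) ** 2 for xk, yk in zip(x, ys[j]))
            acc += b[j] * math.exp((f[i] + g[j] - Cij) / eps) * v[j]
    return a[i] * acc
```

The fused GPU kernel performs the same accumulation with tiles resident in SRAM (and in log-stabilized form), so transport application inherits the $\Theta(nd + md)$ traffic of the update phase.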

6. Empirical Performance and Scalability

On NVIDIA A100–80 GB with Triton 2.1+/PyTorch 2.5+, FlashSinkhorn demonstrates:

  • Synthetic point clouds ($n \in [5\text{k}, 50\text{k}]$, $d \in [4, 1024]$):
    • Forward (dual potentials): up to $32\times$ speedup over KeOps, $3.5$–$12\times$ over tensorized.
    • End-to-end forward/backward: up to $161\times$ speedup over KeOps, up to $5\times$ over OTT-JAX.
    • HVP (50 CG iterations): $3$–$6\times$ speedup over OTT-JAX, $4$–$26\times$ over KeOps.
  • Downstream tasks:
    • OTDD on MNIST↔Fashion-MNIST: matches tensorized speed until $n \sim 20$k, scales to $n = 60$k with $<1$ GB memory, where the tensorized solver runs out of memory.
    • Gradient flow and Hessian spectra in large (shuffled-regression) synthetic settings: enables saddle detection and mode switching, with a $2.8\times$ end-to-end speedup over Adam-only optimization strategies.

For NFFT-accelerated FlashSinkhorn, empirical benchmarks report solving $n = m = 262{,}144$ image OT problems in 3.8 s with kernel-approximation error below $10^{-8}$, far surpassing the dense baseline in both speed and scale. For $n = m = 10^7$ synthetic points, computation completes in ≈60 s using 132 MB of memory (Lakshmanan et al., 2022).

7. Reproducibility, Implementation, and Parameter Selection

The open-source FlashSinkhorn implementation from (Ye et al., 3 Feb 2026) is available at https://github.com/ot-triton-lab/ot_triton (Apache 2.0 licence). Installation is via

pip install ot_triton

Tutorial notebooks reproduce all synthetic and real-data benchmarks. Benchmark timings include warmup and use a fixed iteration count for fairness.

Key runtime parameters:

  • Regularization $\varepsilon = 0.1$
  • Floating point: TF32 for forward/backward, FP32 for HVP
  • Block sizes chosen so that SRAM usage is maximized under hardware constraints

For NFFT-based FlashSinkhorn (Lakshmanan et al., 2022), recommended parameter choices include $N = 128$–$256$ per dimension, window width $p = 6$–$8$, oversampling $\sigma = 2$–$3$, and period $h = 2L$ for support radius $L$. The resulting kernel-approximation bias is bounded by $O(\varepsilon_{\rm NFFT}/\lambda)$.

Reproducibility is facilitated by fully-documented kernel code, parameter rationale, and consistent hardware-agnostic benchmarks.
