FlashSinkhorn: Efficient EOT Computation
- FlashSinkhorn is a family of algorithms that efficiently solves entropic-regularized optimal transport problems using optimized GPU streaming and NFFT acceleration.
- It fuses kernel operations and streams matrix tiles to drastically reduce memory load and HBM traffic, overcoming computational bottlenecks.
- Empirical results show significant speedups over traditional methods, enabling scalable transport computations on modern accelerator architectures.
FlashSinkhorn denotes a family of algorithms designed for efficient large-scale computation of entropic-regularized optimal transport (EOT) via Sinkhorn-type iterations, with highly optimized implementations exploiting either hardware-aware GPU streaming and fusion techniques or nonequispaced fast Fourier transform (NFFT) acceleration. Recent advances under the name "FlashSinkhorn" address the bandwidth, memory, and computational bottlenecks that have limited the scaling of EOT solvers, achieving significant gains both on modern GPU architectures (especially A100) and in dense, CPU-bound settings (Ye et al., 3 Feb 2026, Lakshmanan et al., 2022).
1. Entropic Optimal Transport and the Sinkhorn Algorithm
Let $X = \{x_i\}_{i=1}^{n}$ and $Y = \{y_j\}_{j=1}^{m}$, with weights $a \in \Delta_n$, $b \in \Delta_m$, and the cost matrix $C \in \mathbb{R}^{n \times m}$. The entropic-regularized OT problem is
$$\min_{P \in \Pi(a,b)} \langle C, P \rangle + \varepsilon\,\mathrm{KL}(P \,\|\, a b^\top),$$
where $\Pi(a,b) = \{P \in \mathbb{R}_{\ge 0}^{n \times m} : P\mathbf{1}_m = a,\ P^\top\mathbf{1}_n = b\}$ and $\mathrm{KL}$ is the relative entropy.
Scaling approaches such as Sinkhorn's algorithm alternate between updates to the dual potentials $f$ and $g$; the stabilized log-domain variant is
\begin{align*}
f_i &\gets -\varepsilon\,\mathrm{LogSumExp}_j\left[\frac{g_j - C_{ij}}{\varepsilon} + \log b_j\right],\\
g_j &\gets -\varepsilon\,\mathrm{LogSumExp}_i\left[\frac{f_i - C_{ij}}{\varepsilon} + \log a_i\right].
\end{align*}
The memory and compute cost of these dense $n \times m$ interactions has made such methods impractical at large $n$.
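The log-domain updates above can be sketched as a dense NumPy reference implementation (illustrative only, not the FlashSinkhorn kernel; the convention that the potentials carry units of cost, hence the leading $\varepsilon$, is assumed):

```python
# Illustrative dense log-domain Sinkhorn in NumPy -- a reference
# implementation of the stabilized updates, not the FlashSinkhorn kernel.
import numpy as np

def logsumexp(M, axis):
    # Numerically stable LogSumExp along one axis.
    m = M.max(axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.exp(M - m).sum(axis=axis, keepdims=True)), axis)

def sinkhorn_log(C, a, b, eps=1.0, iters=500):
    """Stabilized Sinkhorn iterations; returns dual potentials (f, g)."""
    f, g = np.zeros(C.shape[0]), np.zeros(C.shape[1])
    for _ in range(iters):
        # f_i <- -eps * LSE_j[(g_j - C_ij)/eps + log b_j], then symmetrically for g.
        f = -eps * logsumexp((g[None, :] - C) / eps + np.log(b)[None, :], axis=1)
        g = -eps * logsumexp((f[:, None] - C) / eps + np.log(a)[:, None], axis=0)
    return f, g
```

The plan is then recovered as $P_{ij} = a_i b_j \exp\big((f_i + g_j - C_{ij})/\varepsilon\big)$; immediately after a $g$-update its column marginals match $b$ exactly, and the row marginals converge to $a$.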
2. FlashSinkhorn: IO-Aware GPU Algorithms
The FlashSinkhorn method in (Ye et al., 3 Feb 2026) addresses performance limits in GPU-based EOT solvers by recasting the log-domain Sinkhorn update as a row-wise LogSumExp over biased dot-product scores. For the squared-Euclidean cost $C_{ij} = \tfrac{1}{2}\|x_i - y_j\|^2$, the LogSumExp argument decomposes as
$$\frac{g_j - C_{ij}}{\varepsilon} + \log b_j = \frac{x_i^\top y_j}{\varepsilon} + \beta_j - \frac{\|x_i\|^2}{2\varepsilon}, \qquad \beta_j = \frac{g_j - \tfrac{1}{2}\|y_j\|^2}{\varepsilon} + \log b_j,$$
i.e., a dot-product score plus a per-column bias and a per-row offset.
This structure enables a direct mapping to Transformer-style attention: LogSumExp reductions of pairwise dot-product matrices with additive biases.
By fusing all computation stages (dot products, biasing, LogSumExp) into a single custom kernel, and streaming matrix tiles through on-chip SRAM, FlashSinkhorn achieves the following workflow:
- Partition the source and target point sets (with their potentials and weights) into blocks fitting in SRAM.
- Stream each tile, compute the required scores, and accumulate row-wise LogSumExp statistics on-chip.
- On completion, update potentials and transfer only summarizing statistics to high-bandwidth memory (HBM).
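The tile-streaming reduction above can be illustrated with a NumPy sketch of the online (running-max) LogSumExp recurrence; the real kernel performs the same recurrence over SRAM tiles in Triton, and the function name and block size here are illustrative:

```python
# Sketch of the tile-streaming LogSumExp reduction in NumPy (the real
# kernel runs in Triton over SRAM tiles). Column tiles of the score
# matrix S = Q @ K.T + bias are visited once; only a running row-max and
# a running rescaled sum are kept, so no n x m intermediate is materialized.
import numpy as np

def streaming_lse(Q, K, bias, block=4):
    """Row-wise LogSumExp of Q @ K.T + bias, one column tile at a time."""
    n = Q.shape[0]
    run_max = np.full(n, -np.inf)   # running row maxima
    run_sum = np.zeros(n)           # running sums of exp(score - run_max)
    for j0 in range(0, K.shape[0], block):
        S = Q @ K[j0:j0 + block].T + bias[j0:j0 + block][None, :]
        new_max = np.maximum(run_max, S.max(axis=1))
        # Rescale the old accumulator, then fold in the new tile's contribution.
        run_sum = run_sum * np.exp(run_max - new_max) \
                  + np.exp(S - new_max[:, None]).sum(axis=1)
        run_max = new_max
    return run_max + np.log(run_sum)
```

Only the two length-$n$ accumulators ever leave the tile loop, which is exactly why HBM traffic stays proportional to the inputs rather than to the $n \times m$ score matrix.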
With block sizes chosen to fit on-chip SRAM, per-iteration HBM traffic no longer scales with the $nm$ pairwise interactions: each tile of points is streamed through SRAM, and only per-row summary statistics are written back, so traffic is linear in the number of points and the feature dimension (Ye et al., 3 Feb 2026).
3. Memory Complexity, Bandwidth and Profiling
A comparative summary of resource usage is as follows:
| Implementation | Intermediates | HBM Traffic (per iter) | Total Memory |
|---|---|---|---|
| Tensorized | $O(nm)$ | $O(nm)$ | $O(nm)$ |
| KeOps-style online | -- | lower, but low SM utilization | $O(n+m)$ |
| FlashSinkhorn | -- | $O(n+m)$ | $0.08$ GB (10 iters, k points) |
Nsight profiling for FlashSinkhorn on A100 GPUs demonstrates only 3% memory stalls and 74% SM utilization, compared with 79% memory stalls for the tensorized algorithm.
4. NFFT-Accelerated FlashSinkhorn
An alternative FlashSinkhorn algorithm, primarily for CPU or moderate-scale problems, accelerates the Sinkhorn iteration using nonequispaced fast Fourier transforms (NFFTs) (Lakshmanan et al., 2022). The cost kernel is approximated by a $d$-dimensional truncated Fourier series whose coefficients are computed efficiently by FFT and smooth windowing. Each Sinkhorn iteration then reduces to two NFFT-based convolution-type products, giving per-iteration runtime linear in the number of points (for fixed Fourier bandwidth), as opposed to the quadratic cost of classical dense-matrix methods.
For typical Gaussian kernels in moderate dimension, a modest window width, oversampling factor, and per-dimension bandwidth suffice to reach machine precision. All kernel interactions are performed implicitly; only the scaling vectors and NFFT plans are stored.
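The separable structure that the NFFT exploits can be shown in a 1-D sketch, assuming a Gaussian kernel and evaluating the nonequispaced transforms naively (an NFFT library would do these two passes in near-linear time); all parameter values here are illustrative:

```python
# 1-D sketch of the Fourier-separable fast sum behind NFFT-Sinkhorn:
# replacing the Gaussian kernel by a truncated Fourier series turns the
# O(n*m) sum  sum_j K(x_i - y_j) v_j  into two "points -> frequencies ->
# points" passes. The nonequispaced transforms are done naively here.
import numpy as np

def fourier_kernel_sum(x, y, v, sigma=0.3, L=3.0, N=32):
    """Approximate sum_j exp(-(x_i - y_j)^2 / (2 sigma^2)) v_j via Fourier modes."""
    ks = np.arange(-N, N + 1)
    # Fourier coefficients of the L-periodized Gaussian kernel.
    c = (sigma * np.sqrt(2 * np.pi) / L) * np.exp(-2 * (np.pi * sigma * ks / L) ** 2)
    # Adjoint transform: source points -> frequency coefficients.
    v_hat = np.exp(-2j * np.pi * np.outer(ks, y) / L) @ v
    # Forward transform: frequency coefficients -> target points.
    return np.real(np.exp(2j * np.pi * np.outer(x, ks) / L) @ (c * v_hat))

rng = np.random.default_rng(1)
x, y = rng.uniform(-0.25, 0.25, 50), rng.uniform(-0.25, 0.25, 60)
v = rng.normal(size=60)
direct = np.exp(-(x[:, None] - y[None, :]) ** 2 / (2 * 0.3 ** 2)) @ v
approx = fourier_kernel_sum(x, y, v)
```

With the period $L$ comfortably larger than the data support and the coefficients decaying like a Gaussian in $k$, both aliasing and truncation errors are driven below working precision, mirroring the machine-precision claims above.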
5. Streaming Kernels for Transport and Second-Order Methods
FlashSinkhorn (Ye et al., 3 Feb 2026) implements not only potential updates but also transport-application and Hessian–vector product (HVP) operations, using the same streaming/fused kernel paradigm:
- Application of the transport plan to a vector by streaming softmax-weighted dot products,
- Adjoint transport application and Hadamard-weighted variants,
- Streaming HVP via a decomposition involving Schur complements and on-the-fly conjugate gradient solves for the dual Schur system.
All operations are unified into a minimal kernel set inheriting the bandwidth and memory advantages of the main update phase.
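A matrix-free conjugate gradient of the kind used for the dual Schur system can be sketched as follows (the operator passed in here is a stand-in SPD map; FlashSinkhorn applies its actual Schur complement through the fused streaming kernels):

```python
# Minimal matrix-free conjugate gradient: only products A @ v are needed,
# which is what makes on-the-fly (streamed) Schur-complement solves possible.
# This is a generic CG sketch, not the FlashSinkhorn HVP kernel itself.
import numpy as np

def cg(apply_A, b, tol=1e-10, maxiter=200):
    """Solve A x = b for SPD A given only the action v -> A v."""
    x = np.zeros_like(b)
    r = b - apply_A(x)          # initial residual
    p, rs = r.copy(), r @ r
    for _ in range(maxiter):
        Ap = apply_A(p)
        alpha = rs / (p @ Ap)   # exact line search along p
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p   # A-conjugate search direction
        rs = rs_new
    return x
```

Because the solver only ever asks for operator-vector products, each CG step inherits the bandwidth profile of the streaming kernels that implement those products.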
6. Empirical Performance and Scalability
On NVIDIA A100–80 GB with Triton 2.1+/PyTorch 2.5+, FlashSinkhorn demonstrates:
- Synthetic point clouds:
- Forward (dual potentials): speedups over KeOps, and at least $3.5\times$ over tensorized.
- End-to-end forward/backward: speedups over both KeOps and OTT-JAX.
- HVP (50 CG iterations): at least $3\times$ speedup over OTT-JAX and at least $4\times$ over KeOps.
- Downstream tasks:
- OTDD on MNIST↔Fashion-MNIST: matches tensorized speed at small scales and continues to scale, with modest GPU memory, to dataset sizes at which the tensorized implementation runs out of memory.
- Gradient flow and Hessian spectra in large (shuffled regression) synthetic settings: enables saddle detection and mode switching, with end-to-end speedups over Adam-only optimization strategies.
For NFFT-accelerated FlashSinkhorn, empirical benchmarks report the solution of image OT problems in 3.8 s with kernel-approximation error near machine precision, far surpassing the dense baseline in both speed and scale. For large synthetic point sets, computation is feasible in ≈60 s with 132 MB of memory (Lakshmanan et al., 2022).
7. Reproducibility, Implementation, and Parameter Selection
The open-source FlashSinkhorn implementation from (Ye et al., 3 Feb 2026) is available at https://github.com/ot-triton-lab/ot_triton (Apache 2.0 licence). Installation is via

```
pip install ot_triton
```
Key runtime parameters:
- Regularization parameter $\varepsilon$
- Floating point: TF32 for forward/backward, FP32 for HVP
- Block sizes chosen so that SRAM usage is maximized under hardware constraints
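The block-size constraint can be made concrete with a back-of-envelope calculation (the 192 KB figure matches A100 shared memory per SM; the tile count and accumulator layout are assumptions for illustration, not the library's actual heuristic):

```python
# Back-of-envelope block sizing under an SRAM budget (illustrative).
# Assumed layout: 'tiles' point tiles of shape (B, d) plus two per-row
# accumulators (running max and running sum) of length B, all in fp32.
def max_block(sram_bytes=192 * 1024, d=3, dtype_bytes=4, tiles=3):
    """Largest block size B with tiles*B*d + 2*B values fitting in SRAM."""
    per_row = tiles * d * dtype_bytes + 2 * dtype_bytes
    return sram_bytes // per_row
```

In practice the result would be rounded down to a power of two and tuned per GPU, but the calculation shows why the feasible block size shrinks as the feature dimension grows.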
For NFFT-based FlashSinkhorn (Lakshmanan et al., 2022), recommended parameter choices include a fixed bandwidth per dimension, a small window width, a modest oversampling factor, and a period matched to the support radius of the data. Machine precision can be matched to within the bias of the windowed kernel approximation.
Reproducibility is facilitated by fully-documented kernel code, parameter rationale, and consistent hardware-agnostic benchmarks.