- The paper introduces a novel linear-time algorithm for Gromov-Wasserstein distances by exploiting low-rank factorizations in both cost matrices and coupling structures.
- It reduces the cubic complexity of traditional methods to linear complexity by approximating input costs and constraining the coupling matrix.
- The approach demonstrates practical scalability on large-scale datasets such as single-cell genomics and human brain data while preserving alignment accuracy.
The paper "Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs" (2106.01128) addresses the computational bottleneck of the Gromov-Wasserstein (GW) problem, which is widely used for aligning and comparing data from different metric spaces, such as point clouds or distributions living in heterogeneous feature spaces. The standard approach to solving GW approximately, based on iteratively solving entropy-regularized Optimal Transport (OT) problems, suffers from cubic complexity O(n³) in the number of samples n, making it impractical for large datasets. This paper proposes novel methods to reduce the GW computation time, ultimately achieving a linear-time algorithm by exploiting low-rank structures in both the input cost matrices and the coupling matrix.
The standard GW problem seeks a coupling matrix P ∈ R^{n×m} between two discrete measures with n and m samples, represented by cost matrices A ∈ R^{n×n} and B ∈ R^{m×m} that encode the geometry within each space. The objective is to minimize a quadratic function Q_{A,B}(P) that measures the distortion introduced by the coupling P. The standard entropic GW approximation solves this non-convex problem iteratively using Mirror Descent, which boils down to solving a sequence of entropy-regularized OT problems with a synthetic cost matrix C_t = −4 A P_{t−1} B. The computational bottleneck arises from the following steps in each iteration:
- Updating the synthetic cost matrix C_t = −4 A P_{t−1} B, which requires O(n²m + nm²) operations.
- Evaluating the GW objective Q_{A,B}(P), which costs O(n²m²) naively but can be computed in O(n²m + nm²) operations using a known reformulation.
- Solving the entropy-regularized OT problem with cost C_t using Sinkhorn's algorithm, which takes O(nm) operations per Sinkhorn iteration; the dominant cost per outer GW iteration remains the O(n²m + nm²) cost of updating C_t.
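The steps above can be made concrete in a short sketch. The helper below is a hypothetical, simplified version of one entropic-GW outer iteration in dense NumPy (function name, defaults, and the deliberately large regularization are illustrative choices for a quick demo, not the paper's implementation):

```python
import numpy as np

def entropic_gw_step(A, B, P, a, b, eps=10.0, n_sinkhorn=1000):
    """One simplified outer iteration of entropic GW (illustrative sketch).

    The synthetic cost update C = -4 A P B is the cubic bottleneck:
    O(n^2 m + n m^2) operations for dense A, B.
    """
    C = -4.0 * A @ P @ B                 # dominant cost
    K = np.exp(-C / eps)                 # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_sinkhorn):          # each Sinkhorn sweep is only O(nm)
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]   # rescaled transport plan

# toy usage: two small point clouds with intra-space distance matrices
rng = np.random.default_rng(0)
n, m = 20, 15
X, Y = rng.normal(size=(n, 3)), rng.normal(size=(m, 3))
A = np.linalg.norm(X[:, None] - X[None, :], axis=-1)
B = np.linalg.norm(Y[:, None] - Y[None, :], axis=-1)
a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
P = np.outer(a, b)                       # product-coupling initialization
for _ in range(5):
    P = entropic_gw_step(A, B, P, a, b)
```

Even at this toy scale, the `A @ P @ B` product is the step that would dominate as n and m grow.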
The paper tackles this cubic complexity by introducing two independent strategies and then showing how to combine them:
1. Low-rank (Approximated) Costs:
If the input cost matrices A and B admit low-rank factorizations A = A₁A₂ᵀ and B = B₁B₂ᵀ, where A₁, A₂ ∈ R^{n×d} and B₁, B₂ ∈ R^{m×d'} with d ≪ n and d' ≪ m, the complexity of updating the synthetic cost matrix can be reduced:
C_t = −4 A₁A₂ᵀ P_{t−1} B₁B₂ᵀ.
Computing this product is more efficient when done from the inside out: first compute P_{t−1} B₁ in O(nmd') operations, then A₂ᵀ(P_{t−1} B₁) in O(ndd'), and finally multiply by A₁ and B₂ᵀ. If d and d' are small constants, this reduces the update cost to O(nm(d + d')), which is O(nm) per outer iteration, i.e. quadratic rather than cubic when n = m. Similarly, the evaluation of Q_{A,B}(P) can be sped up to the same quadratic order.
This strategy is particularly relevant for squared Euclidean distance matrices, where an exact low-rank factorization exists with rank related to the ambient dimension (rank d + 2 for points in R^d). For general distance matrices, recent work allows for computing low-rank approximations in nearly linear time, enabling this speedup even when an exact factorization is not obvious. Algorithm 2 outlines this "Quadratic Entropic-GW" approach.
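The speedup comes purely from re-associating the matrix product. A minimal NumPy sketch (factor names `A1`, `A2`, `B1`, `B2` are illustrative), verifying the factored update against the naive one:

```python
import numpy as np

rng = np.random.default_rng(1)
n, m, d, dp = 300, 250, 4, 5

# hypothetical low-rank factors: A = A1 @ A2.T (n x n), B = B1 @ B2.T (m x m)
A1, A2 = rng.normal(size=(n, d)), rng.normal(size=(n, d))
B1, B2 = rng.normal(size=(m, dp)), rng.normal(size=(m, dp))
P = np.full((n, m), 1.0 / (n * m))       # any coupling

# naive update: forms the n x n and m x m costs, O(n^2 m + n m^2)
C_naive = -4.0 * (A1 @ A2.T) @ P @ (B1 @ B2.T)

# factored update: O(nm(d + d')), never materializes A or B
M = A2.T @ (P @ B1)                      # d x d' core
C_fast = -4.0 * A1 @ M @ B2.T

assert np.allclose(C_naive, C_fast)
```

Both paths produce the same C_t; only the association order (and hence the complexity) differs.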
2. Low-rank Constraints for Couplings:
Instead of assuming low-rank input costs, the paper proposes constraining the coupling matrix P to have a low nonnegative rank. This is achieved by restricting P to the form P = Q diag(1/g) Rᵀ, where Q ∈ R^{n×r} and R ∈ R^{m×r} are nonnegative matrices satisfying certain marginal constraints (specifically, Q 1_r = a, R 1_r = b, and Qᵀ 1_n = Rᵀ 1_m = g) and g is a common intermediate marginal. This factorization implies P has a nonnegative rank at most r.
The GW problem is then reformulated as minimizing Q_{A,B}(Q diag(1/g) Rᵀ) over the triplet (Q, R, g) in the corresponding feasible set. This problem is solved using a Mirror Descent scheme w.r.t. the KL divergence in the space of (Q, R, g). Each step involves computing generalized kernel matrices from the gradients and then solving a barycenter problem efficiently using Dykstra's algorithm (Algorithm 3).
The initialization uses a low-rank approximation of a lower bound based on the squared norms of the rows/columns of A and B. This initialization itself can be computed efficiently.
While Dykstra's algorithm for the barycenter step takes O((n + m)r) operations per inner iteration, computing the kernel matrices still involves matrix products such as A Q and B R, which require O(n²r) and O(m²r) operations in the general case. Thus, this approach alone reduces the complexity per outer iteration to O((n² + m²)r).
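The coupling factorization itself is easy to illustrate. Below, a trivially feasible triplet (Q, R, g) is assembled (the product coupling, chosen only for illustration, not an optimizer) and its marginal constraints are checked:

```python
import numpy as np

n, m, r = 30, 25, 4
a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
g = np.full(r, 1.0 / r)                  # common intermediate marginal

# a simple feasible point: Q = a g^T, R = b g^T (illustrative only)
Q, R = np.outer(a, g), np.outer(b, g)
P = (Q / g) @ R.T                        # P = Q diag(1/g) R^T, nonneg. rank <= r

# marginal constraints of the feasible set
assert np.allclose(Q.sum(axis=1), a) and np.allclose(Q.sum(axis=0), g)
assert np.allclose(R.sum(axis=1), b) and np.allclose(R.sum(axis=0), g)
assert np.allclose(P.sum(axis=1), a) and np.allclose(P.sum(axis=0), b)
```

Any (Q, R, g) obeying these constraints yields a valid coupling with nonnegative rank at most r, which is what the Mirror Descent scheme optimizes over.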
3. Double Low-rank GW:
The key contribution is showing that combining both low-rank strategies yields a linear-time algorithm. If both cost matrices have low-rank factorizations (A = A₁A₂ᵀ, B = B₁B₂ᵀ) and the coupling is constrained to be low-rank (P = Q diag(1/g) Rᵀ), the critical computation A P B becomes
A₁A₂ᵀ Q diag(1/g) Rᵀ B₁B₂ᵀ.
This can be evaluated without ever forming an n × m (or larger) intermediate: compute A₂ᵀQ in O(ndr), Rᵀ B₁ in O(mrd'), and the small core M = (A₂ᵀQ) diag(1/g) (RᵀB₁) of size d × d' in O(rdd'). More strategically, one can keep A P B in factored form, A P B = (A₁M) B₂ᵀ, the product of an n × d' factor and an m × d' factor. The necessary terms for the kernel matrices can then be computed from these factors efficiently.
Specifically, the Mirror Descent updates in the low-rank coupling formulation (Algorithm 3) do not need A P B itself but only the gradients of the objective with respect to Q, R, and g, which involve terms of the form (up to constant factors):
- A P B R diag(1/g) (for the update of Q),
- B Pᵀ A Q diag(1/g) (for the update of R),
- terms built from the diagonal entries of Qᵀ A P B R (for the update of g).
Substituting P = Q diag(1/g) Rᵀ together with the factorizations of A and B, each of these terms becomes a chain of thin matrix products. Naively multiplying out A P B first would reintroduce a large intermediate and is not where the linearization comes from.
The linear time comes from carefully evaluating these gradients in the Mirror Descent update (Equation 6 in the paper):
The gradient w.r.t. Q involves A P B R diag(1/g). With the factorizations, this equals
A₁ (A₂ᵀQ) diag(1/g) (RᵀB₁) (B₂ᵀR) diag(1/g).
Evaluating this chain from the inside out, the individual products cost:
- A₂ᵀQ (a d × r matrix): O(ndr)
- RᵀB₁ (r × d'): O(mrd')
- B₂ᵀR (d' × r): O(md'r)
- the diagonal rescalings by 1/g: O(dr) and O(d'r)
- the small chained products (d × r)(r × d') and (d × d')(d' × r): O(rdd') each
- the final multiplication by A₁ (n × d times d × r): O(ndr)
So the gradient w.r.t. Q is dominated by O((n + m)(d + d')r) operations, and the gradient w.r.t. R is handled symmetrically. If d, d', and r are constants, this is O(n + m); if they grow slowly with the sample sizes, the cost is still close to linear.
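The chain above can be sketched directly (illustrative factor names, with a dense reference only for verification):

```python
import numpy as np

rng = np.random.default_rng(3)
n, m, r, d, dp = 200, 150, 5, 3, 4
A1, A2 = rng.normal(size=(n, d)), rng.normal(size=(n, d))
B1, B2 = rng.normal(size=(m, dp)), rng.normal(size=(m, dp))
Q, R = rng.random((n, r)), rng.random((m, r))
g = rng.random(r) + 0.5

# thin chain: every factor has at most max(n, m) rows and few columns
T1 = A2.T @ Q / g                        # d x r   = A2^T Q diag(1/g)
T2 = R.T @ B1                            # r x d'
T3 = B2.T @ R / g                        # d' x r  = B2^T R diag(1/g)
grad_term_fast = A1 @ ((T1 @ T2) @ T3)   # n x r, cost linear in n + m

# naive reference: materializes A, B, P and large intermediates
A, B, P = A1 @ A2.T, B1 @ B2.T, (Q / g) @ R.T
grad_term_naive = A @ P @ B @ R / g

assert np.allclose(grad_term_fast, grad_term_naive)
```

Only n × r and thin d/d'-sized matrices ever appear in the fast path, which is the whole point of the double low-rank structure.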
Similarly, evaluating the objective Q_{A,B}(P) itself takes linear time. Its cross term ⟨A P B, P⟩ can be computed by recognizing that, with the factorizations, it equals the Frobenius inner product ⟨U, V⟩ of two small d × d' matrices U = A₂ᵀ P B₁ and V = A₁ᵀ P B₂, each of which is a chain of thin products computable in O((n + m)(d + d')r) operations; the final dot product ⟨U, V⟩ costs only O(dd'). The remaining terms of the objective, which depend on the marginals through the entrywise squares A ⊙ A and B ⊙ B, can also be computed in nearly linear time: if A = A₁A₂ᵀ is factorized, then A ⊙ A admits a factorization as well, though with potentially larger rank d². Computing these terms costs O(nd² + md'²).
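The A ⊙ A trick can be sketched as well: a rank-d factorization of A yields a rank-d² factorization of its entrywise square via pairwise column products (illustrative names, dense reference only for the check):

```python
import numpy as np

rng = np.random.default_rng(4)
n, d = 300, 3
A1, A2 = rng.normal(size=(n, d)), rng.normal(size=(n, d))
a = np.full(n, 1.0 / n)

# columns of H1 are all pairwise products of columns of A1 (rank d^2 factors),
# so that A ⊙ A = H1 @ H2.T without ever forming the n x n matrix
H1 = np.einsum('ik,il->ikl', A1, A1).reshape(n, d * d)
H2 = np.einsum('ik,il->ikl', A2, A2).reshape(n, d * d)

# a^T (A ⊙ A) a in O(n d^2), vs O(n^2) when A ⊙ A is formed explicitly
quad_fast = (H1.T @ a) @ (H2.T @ a)

A = A1 @ A2.T
quad_naive = a @ ((A * A) @ a)
assert np.isclose(quad_fast, quad_naive)
```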
Thus, under both low-rank assumptions, the computation per outer iteration becomes linear in the number of samples, O((n + m) · poly(d, d', r)). With d, d', and r small, this is O(n + m).
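Finally, the objective's cross term ⟨A P B, P⟩ reduces to a Frobenius inner product of two d × d' matrices, closing the loop on the linear-time evaluation. A sketch under the same illustrative factorizations:

```python
import numpy as np

rng = np.random.default_rng(5)
n, m, r, d, dp = 200, 150, 5, 3, 4
A1, A2 = rng.normal(size=(n, d)), rng.normal(size=(n, d))
B1, B2 = rng.normal(size=(m, dp)), rng.normal(size=(m, dp))
Q, R = rng.random((n, r)), rng.random((m, r))
g = rng.random(r) + 0.5

# <A P B, P> = <A2^T P B1, A1^T P B2>_F, with P kept in factored form
U = (A2.T @ Q / g) @ (R.T @ B1)          # d x d'
V = (A1.T @ Q / g) @ (R.T @ B2)          # d x d'
cross_fast = np.sum(U * V)

# dense reference (only for verification)
A, B, P = A1 @ A2.T, B1 @ B2.T, (Q / g) @ R.T
cross_naive = np.sum((A @ P @ B) * P)
assert np.isclose(cross_fast, cross_naive)
```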
Implementation and Applications:
The paper provides algorithms for the quadratic methods (Algorithms 2 and 3) and the linear method (Section 5, combining aspects of both).
- Initialization: A warm start is crucial for non-convex optimization. The proposed initialization uses a low-rank OT problem based on the squared norms of the rows/columns of A and B, which can be computed in linear time under low-rank cost assumptions.
- Optimization: Mirror Descent is used. For the low-rank coupling method, each MD step requires solving a barycenter problem using Dykstra's algorithm. The paper shows experimentally that the number of Dykstra iterations does not depend heavily on the problem size n, which is favorable.
- Hyperparameters: The method has hyperparameters such as the low rank r and the step size γ (or the regularization ε if double regularization is used). The paper explores the sensitivity to γ and to a lower bound α on the entries of g, finding the method relatively robust. The choice of r affects the quality of the approximation; ideally, it should relate to the intrinsic dimension or number of clusters in the data.
- Computational Cost: The paper provides clear complexity analyses: cubic (O(n²m + nm²)) for standard entropic GW, quadratic for quadratic GW (low-rank costs or low-rank couplings separately), and linear for linear GW (both low-rank costs and couplings), with multiplicative factors depending on the ranks d, d', and r.
- Real-world Applications: The methods are demonstrated on single-cell genomics data (SNAREseq and Splatter) and a human brain dataset (BRAIN). These applications involve aligning point clouds representing cells characterized by different molecular features (e.g., gene expression and chromatin accessibility). The distance metric used is often based on k-NN graphs and shortest paths. While shortest path distance matrices don't automatically admit low-rank factorizations like Euclidean distance, the quadratic version of the algorithm is applicable by simply computing the full distance matrix. The linear version is demonstrated on the BRAIN dataset using squared Euclidean distance after PCA, where the low-rank factorization is available.
- Performance: Experiments show that the proposed LR (quadratic) and Lin LR (linear) methods achieve similar GW loss and downstream task performance (like cell type alignment measured by FOSCTTM) compared to the standard Entropic-GW and MREC baselines, but are orders of magnitude faster, particularly at large scales. Lin LR is shown to be the only viable method for very large datasets.
- Limitations: The linearity relies on the assumption that the intrinsic dimensionality of the data (reflected in the rank of the cost matrices) and the required rank r for the coupling are small relative to n and m. Tuning r might be necessary in practice.
Overall, the paper presents a significant step towards making Gromov-Wasserstein scalable by introducing and demonstrating the effectiveness of low-rank approaches for both the geometry of the input spaces and the structure of the coupling. The combined linear-time method offers a practical way to apply GW to large-scale problems previously inaccessible due to computational constraints.