Geodesic Mode Connectivity

Published 24 Aug 2023 in cs.LG and stat.ML (arXiv:2308.12666v1)

Abstract: Mode connectivity is a phenomenon where trained models are connected by a path of low loss. We reframe this in the context of Information Geometry, where neural networks are studied as spaces of parameterized distributions with curved geometry. We hypothesize that shortest paths in these spaces, known as geodesics, correspond to mode-connecting paths in the loss landscape. We propose an algorithm to approximate geodesics and demonstrate that they achieve mode connectivity.

Citations (2)

Summary

Explain it Like I'm 14

Overview

This paper looks at how to connect two trained neural networks with a smooth path where the models still work well along the way. Imagine the “performance” of a model as a landscape with hills and valleys: good models sit in low valleys (low loss), and bad models sit on high hills (high loss). The authors suggest that the best way to travel between two good models is to follow a shortest path in a special curved space of model predictions. They call these shortest paths “geodesics,” and they show an algorithm that finds such paths that stay in low-loss areas.

Objectives

The paper aims to answer two simple questions:

  1. Can we think about neural networks not just as numbers, but as points in a curved space of their predictions, and use shortest paths in that space to connect good models?
  2. If we try to follow these shortest paths (geodesics), will we actually get a low-loss route between two trained models, even when straight-line paths fail?

Methods and Approach

To make this idea practical, the authors do the following:

  • They treat each neural network as a “distribution” over outputs: for any input image, the network predicts probabilities for each possible label. This turns the set of all networks into a curved space, because the way we measure changes in predictions isn’t flat like a sheet of paper—it’s more like the surface of a globe.
  • A geodesic is the shortest path in that curved space. On a globe, the shortest path between two cities is along a great circle, which looks curved on a flat map. Similarly, the best path between two networks can look curved if you just look at the raw weights, but it’s “short” in terms of how the predictions change.
  • They need a way to measure how different two networks’ predictions are. For this, they use the Jensen–Shannon Divergence (JSD), which you can think of as a score for “how far apart” two sets of probability predictions are (the paper actually uses its square root, which behaves like a true distance). Lower JSD means the predictions are more similar.
  • They create a path between two trained networks by placing several “checkpoints” (intermediate models) between them. Instead of leaving these checkpoints on a straight line, they adjust them to make the total JSD along the path as small as possible. This nudges the path into a curved route that better follows a geodesic in the prediction space.
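As a concrete illustration, here is a minimal NumPy sketch of how a JSD score between two networks’ batched class-probability outputs could be computed. The function name and toy numbers are illustrative, not the authors’ code:

```python
import numpy as np

def jsd(p, q, eps=1e-12):
    """Jensen-Shannon divergence between two batches of categorical
    predictions, averaged over the batch. Bounded above by ln(2)."""
    m = 0.5 * (p + q)                                  # mixture distribution
    kl_pm = np.sum(p * (np.log(p + eps) - np.log(m + eps)), axis=-1)
    kl_qm = np.sum(q * (np.log(q + eps) - np.log(m + eps)), axis=-1)
    return float(np.mean(0.5 * kl_pm + 0.5 * kl_qm))

# Two "networks" predicting over 3 classes for the same 2 inputs.
p = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
q = np.array([[0.6, 0.3, 0.1], [0.2, 0.7, 0.1]])
identical = jsd(p, p)   # 0.0: identical predictions
different = jsd(p, q)   # small positive: predictions diverge slightly
```

Because JSD is symmetric and bounded, it gives a well-behaved score even when the two networks disagree completely.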

Here is the high-level procedure:

  • Start with two trained models.
  • Linearly interpolate between them to create a sequence of intermediate models (checkpoints).
  • Keep the ends fixed, and move the intermediate models so that the sum of JSD between each neighboring pair is minimized.
  • Use only the training images to measure prediction differences; no labels are needed for this step.
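The procedure above can be sketched end to end on a toy model family. Here a 3-class softmax over a logit vector stands in for a network’s predictive distribution; the discrete path energy is the sum of JSDs between neighbouring checkpoints (since square-root JSD acts as the distance, each JSD is a squared segment length), and the interior checkpoints are nudged downhill by finite-difference gradient descent. This is an illustrative simplification, not the authors’ implementation:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def jsd(p, q, eps=1e-12):
    m = 0.5 * (p + q)
    kl = lambda a, b: np.sum(a * (np.log(a + eps) - np.log(b + eps)))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def path_energy(thetas):
    # Discrete energy: sum of JSDs between neighbouring checkpoints.
    ps = [softmax(t) for t in thetas]
    return sum(jsd(ps[i], ps[i + 1]) for i in range(len(ps) - 1))

# Two fixed "trained models" (logit vectors) and 5 interior checkpoints,
# initialised by linear interpolation.
a, b = np.array([2.0, -2.0, 0.0]), np.array([-2.0, 2.0, 0.0])
n = 5
path = [a + (b - a) * i / (n + 1) for i in range(n + 2)]
energy_before = path_energy(path)

# Keep the endpoints fixed; nudge the interior checkpoints downhill on
# the energy via central finite differences.
lr, h = 0.1, 1e-5
for _ in range(200):
    for i in range(1, n + 1):
        grad = np.zeros_like(path[i])
        for j in range(len(grad)):
            bump = np.zeros_like(path[i]); bump[j] = h
            up = path_energy(path[:i] + [path[i] + bump] + path[i + 1:])
            dn = path_energy(path[:i] + [path[i] - bump] + path[i + 1:])
            grad[j] = (up - dn) / (2 * h)
        path[i] = path[i] - lr * grad

energy_after = path_energy(path)   # lower than energy_before
```

In the paper the same idea is applied to full networks, with the JSD estimated from the models’ predictions on training inputs rather than computed in closed form.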

The authors also test a known trick called “weight matching,” which reorders parts of a network to line up similar features. This can sometimes make straight-line paths work, but the paper focuses on cases where straight lines still fail.
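Weight matching can be illustrated in miniature. The real algorithm (Ainsworth et al.) solves an assignment problem across all layers; the sketch below aligns a single hidden layer by greedily matching units on cosine similarity, with all names and data purely illustrative:

```python
import numpy as np

def match_hidden_units(W_a, W_b):
    """Permute the rows (hidden units) of W_b to best line up with W_a.
    Greedy cosine matching; a toy stand-in for weight matching."""
    A = W_a / np.linalg.norm(W_a, axis=1, keepdims=True)
    B = W_b / np.linalg.norm(W_b, axis=1, keepdims=True)
    sim = A @ B.T                       # cosine similarity of unit pairs
    perm = np.zeros(len(W_a), dtype=int)
    used = set()
    for i in range(len(W_a)):
        order = np.argsort(-sim[i])     # best candidates first
        j = next(int(k) for k in order if int(k) not in used)
        perm[i] = j
        used.add(j)
    return W_b[perm]

rng = np.random.default_rng(0)
W_a = rng.normal(size=(4, 3))
W_b = W_a[rng.permutation(4)]          # same units, shuffled order
W_aligned = match_hidden_units(W_a, W_b)
# W_aligned now equals W_a row for row.
```

In a real network, permuting one layer’s output units also requires permuting the next layer’s input weights so the overall function is unchanged; that bookkeeping is omitted here.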

Main Findings

  • On a standard image classification task (ResNet20 on CIFAR-10), straight-line paths between two good models often travel through high-loss regions, meaning the models perform poorly along the way.
  • After applying their geodesic-based optimization, the entire path becomes low-loss. In other words, every checkpoint model along the path performs well—this is “mode connectivity.”
  • Importantly, they show this works even for narrower networks, where previous linear methods struggled. The geodesic path can be longer in ordinary weight space, but it is shorter in the prediction space, which matters for keeping loss low.

Why This Matters

  • It changes how we think about model training: instead of isolated “good spots,” the landscape looks like wide connected valleys where many good solutions are linked.
  • It suggests a better way to combine or transition between models without breaking performance, which could be useful for:
    • Model ensembling or merging different trained models
    • Continual learning, where a model evolves over time
    • Federated learning, where multiple models trained in different places are combined

Implications and Impact

If geodesic paths reliably connect good models, we can:

  • Move between solutions without losing performance, potentially making training more flexible and robust.
  • Understand the “shape” of the model space better, helping researchers design algorithms that avoid high-loss regions.
  • Explore new ways to share or combine models across teams and devices, since we can safely travel between them along low-loss routes.

The authors’ algorithm is a step toward making these ideas practical, especially when simpler methods like straight-line interpolation don’t work. Future work could compare this approach more deeply with other path-finding methods or provide stronger theory explaining when and why geodesic paths guarantee low loss.

Knowledge Gaps

Below is a concise list of unresolved issues, uncertainties, and missing analyses that the paper leaves open for future work:

  • Lack of theory: no proof that geodesics in the Fisher–Rao (FR) information geometry necessarily correspond to low-loss (mode-connecting) paths between SGD solutions.
  • Conditions for validity: unclear under what assumptions on model class, data distribution, or training procedure geodesics are guaranteed to avoid high-loss regions.
  • Uniqueness and multiplicity: no analysis of whether geodesics between two trained solutions are unique, and if multiple geodesics exist, whether all are mode-connecting.
  • Convergence guarantees: the discrete energy minimization over path points has no convergence or optimality guarantees to true geodesics in the FR geometry.
  • Discretization bias: no study of how the number of path points N, spacing, or parameterization affects approximation error, path quality, or stability.
  • Sensitivity to initialization: the method initializes a straight-line path after weight matching; necessity and impact of this pre-alignment (vs. random or alternative initializations) are not evaluated.
  • Necessity of permutation alignment: unclear whether geodesic optimization can succeed without prior weight matching or how sensitive it is to imperfect alignment.
  • Metric fidelity: the use of summed finite JSD increments as a proxy for the FR energy is only justified infinitesimally; finite-step approximation error is unquantified.
  • Choice of divergence: no comparison between JSD and other divergences/metrics (e.g., symmetric KL, Hellinger, Wasserstein) for defining or approximating geodesics.
  • Empirical estimation details: the paper does not specify how JSD over the joint p(x, ŷ; θ) is estimated in practice (batching, smoothing, label-space averaging), nor its variance.
  • Sample complexity: no analysis of how many inputs are needed to reliably estimate JSD for stable optimization, or how batch size affects outcomes.
  • Computational and memory cost: no complexity analysis or profiling for optimizing N intermediate models, especially for larger architectures or datasets.
  • Scalability: no experiments beyond ResNet20 on CIFAR-10; applicability to larger models (e.g., ResNet50, Transformers) or datasets (e.g., ImageNet) is untested.
  • Architectural breadth: results rely on LayerNorm (to preserve permutation symmetry); behavior with BatchNorm (common in practice) and methods to handle its symmetries are not explored.
  • Generality across training regimes: paths between models trained with different optimizers, schedules, regularization, data augmentations, or label noise are not evaluated.
  • Robustness across seeds: no statistics over multiple independent pairs/models to assess success rates, variability, or failure modes.
  • Comparison to baselines: no quantitative comparison to established mode connectivity/path methods (e.g., Garipov et al.’s curves, SWA/SWAG, Bezier paths, LMC when successful).
  • Path quality metrics: beyond loss curves, there is no reporting of path length in the FR geometry, Euclidean length, or curvature; trade-offs are unquantified.
  • Generalization along the path: only loss is shown; accuracy, calibration, and robustness (e.g., to distribution shift or adversarial perturbations) along the path are not assessed.
  • Overfitting risk: since unlabeled training inputs are used to optimize the path, potential overfitting of the path to the training distribution (vs. generalization) is not investigated.
  • Constant-velocity property: the energy formulation implies constant-speed geodesics; constancy of segment-wise JSD increments is not validated empirically.
  • Degenerate solutions: no constraints or diagnostics to prevent collapse of intermediate points (e.g., path points drifting toward endpoints or each other), and no analysis of such degeneracies.
  • Endpoint diversity: only models trained from different random seeds are considered; paths between models from different training epochs, objectives, or data subsets are not explored.
  • Cross-architecture connectivity: feasibility of geodesic paths between different architectures (e.g., ResNet ↔ DenseNet) remains unaddressed.
  • Regression and other tasks: the approach is only demonstrated for classification; extension to regression, structured prediction, or generative modeling is untested.
  • Theoretical linkage to loss: no formal connection is provided between minimizing FR energy (via JSD) and bounding task loss (e.g., cross-entropy) along the path.
  • Effect of hyperparameters: learning rate, optimizer choice for path optimization, number of nodes N, and training duration for the path are not ablated.
  • Reliability under BN/statistics mismatch: implications of BatchNorm’s running statistics (if used) for path validity and loss evaluation along the path are not discussed.
  • Practical deployment: storing and interpolating among many checkpoints may be costly; strategies for compressing or parameterizing the path (e.g., low-rank curves) are not explored.
  • Visualization and diagnostics: no low-dimensional projections or intrinsic-coordinate visualizations to confirm that the found paths align with geodesics in distribution space.
  • Failure characterization: conditions under which geodesic optimization fails (e.g., highly narrow models, severe permutation misalignment, extreme nonconvexity) are not identified.

Glossary

  • BatchNorm: A normalization technique that standardizes activations within a mini-batch to stabilize and speed up training in deep networks. "This is due to BatchNorm layers lacking invariance to permutations."
  • CIFAR-10: A widely used benchmark dataset of 60,000 32x32 color images across 10 classes for image classification. "on the CIFAR-10 dataset"
  • Conditional distribution: A probability distribution of outputs given inputs, used to interpret a neural network’s predictive behavior. "interpreted as a conditional distribution"
  • Distribution space: The manifold of parameterized probability distributions a model induces over data and labels. "in the curved distribution space"
  • Energy functional: A quantity whose minimization yields the geodesic of constant velocity between two distributions. "minimizing the energy functional which gives the unique geodesic of constant velocity."
  • Euclidean length: The straight-line length measured in parameter space with standard Euclidean geometry. "this curvature increases the Euclidean length (dashed) compared to the linear path"
  • Fisher information matrix: A matrix capturing the curvature of the parameterized distribution space, equivalent to the Fisher-Rao metric. "This is also known, in broader contexts, as the Fisher information matrix."
  • Fisher-Rao metric: The Riemannian metric on the manifold of probability distributions used to measure lengths and define geodesics. "where g_{ij} is the Fisher-Rao metric"
  • Geodesic: The shortest path between two points on a curved space or manifold. "One can then define geodesics as the paths of shortest length between two points in this space."
  • Geodesic mode connectivity: A low-loss connection between trained models achieved via geodesic paths in distribution space. "which we refer to as geodesic mode connectivity."
  • Geodesic optimization: An algorithmic procedure to approximate geodesic paths by minimizing a discretized length or divergence-based functional. "Geodesic optimization achieves mode connectivity for ResNet20 on CIFAR-10."
  • Information Geometry: A field applying differential geometric tools to study parameterized probability distributions and statistical models. "We reframe this in the context of Information Geometry"
  • Jensen-Shannon Divergence (JSD): A symmetric measure of dissimilarity between probability distributions used to define path length in distribution space. "square root Jensen-Shannon Divergence (JSD)"
  • LayerNorm: A normalization technique that standardizes activations across the features of each individual sample. "by the use of LayerNorm instead of BatchNorm"
  • Length functional: The integral expression that measures the length of a path under a given metric, minimized by geodesics. "This path minimizes the length functional defined in Equation \ref{eq:jsd_path}"
  • Linear mode connectivity (LMC): A phenomenon or method where two minima are connected by a straight-line path in parameter space without loss increase. "achieved linear mode connectivity (LMC)"
  • Loss basin: A connected region of the parameter space with low loss values around a minimum. "into the same loss basin as the other"
  • Loss landscape: The function surface mapping model parameters to loss values, often visualized to understand optimization behavior. "mode-connecting paths in the loss landscape."
  • Mode connectivity: The observation that different trained models (minima) can be linked by paths along which the loss remains low. "Mode connectivity is a phenomenon where trained models are connected by a path of low loss."
  • Permutation symmetries: Structural symmetries of neural networks (e.g., permuting neurons or filters) that leave the function unchanged. "natural permutation symmetries of neural network layers"
  • ResNet20: A specific residual neural network architecture with 20 layers used for image classification benchmarks. "ResNet20 on CIFAR-10."
  • Riemannian manifold: A smooth geometric space equipped with a metric that enables measurement of lengths and angles. "is a Riemannian manifold equipped with a Riemannian metric"
  • Riemannian metric: A smoothly varying inner product on the tangent space of a manifold that defines distances and geodesics. "equipped with a Riemannian metric"
  • Weight matching algorithm: A method to align and permute neural network parameters to account for permutation symmetries before path construction. "We then employ the weight matching algorithm of \citet{ainsworth_git_2022}"
