
Flow Annealed Importance Sampling Bootstrap

Published 3 Aug 2022 in cs.LG, q-bio.QM, and stat.ML | arXiv:2208.01893v3

Abstract: Normalizing flows are tractable density models that can approximate complicated target distributions, e.g. Boltzmann distributions of physical systems. However, current methods for training flows either suffer from mode-seeking behavior, use samples from the target generated beforehand by expensive MCMC methods, or use stochastic losses that have high variance. To avoid these problems, we augment flows with annealed importance sampling (AIS) and minimize the mass-covering $\alpha$-divergence with $\alpha=2$, which minimizes importance weight variance. Our method, Flow AIS Bootstrap (FAB), uses AIS to generate samples in regions where the flow is a poor approximation of the target, facilitating the discovery of new modes. We apply FAB to multimodal targets and show that we can approximate them very accurately where previous methods fail. To the best of our knowledge, we are the first to learn the Boltzmann distribution of the alanine dipeptide molecule using only the unnormalized target density, without access to samples generated via Molecular Dynamics (MD) simulations: FAB produces better results than training via maximum likelihood on MD samples while using 100 times fewer target evaluations. After reweighting the samples, we obtain unbiased histograms of dihedral angles that are almost identical to the ground truth.

Citations (59)

Summary

  • The paper introduces FAB, which employs annealed importance sampling with α-divergence (α=2) minimization to reduce importance weight variance.
  • FAB uses a prioritized replay buffer to efficiently reuse AIS samples, reducing the need for costly MCMC evaluations.
  • FAB outperforms competing methods on tasks like 2D Gaussian mixtures and alanine dipeptide, accurately approximating multimodal target distributions.

Flow Annealed Importance Sampling Bootstrap: A Detailed Summary

This paper introduces Flow Annealed Importance Sampling Bootstrap (FAB), a novel method for training normalizing flows to approximate complex, multimodal target distributions. FAB addresses limitations of existing approaches, such as mode-seeking behavior, reliance on expensive MCMC samples, and high-variance stochastic losses. The method leverages annealed importance sampling (AIS) and minimizes the mass-covering $\alpha$-divergence with $\alpha=2$, which is shown to minimize importance weight variance.

Key Contributions

  • $\alpha$-Divergence Minimization with AIS: FAB uses the $\alpha$-divergence with $\alpha=2$ as its training objective. This choice encourages mass-covering behavior and reduces variance when importance sampling is used to correct for bias. The method employs AIS with the flow as the initial distribution and $p^2/q$ as the target, where $p$ is the target distribution and $q$ is the flow.
  • Replay Buffer for Sample Re-use: To reduce computational costs, FAB introduces a prioritized replay buffer that re-uses AIS samples during flow updates. This buffer allows the method to efficiently learn from past samples, improving the stability and speed of training.
  • Demonstrated Performance on Challenging Problems: FAB is evaluated on a 2D Gaussian mixture model, the 32-dimensional "Many Well" problem, and the Boltzmann distribution of alanine dipeptide. The method outperforms competing approaches, demonstrating its ability to accurately approximate complex, multimodal targets without relying on samples from those distributions.

Methodological Details

Normalizing Flows and the $\alpha$-Divergence

Normalizing flows use invertible transformations to map a simple base distribution to a complex target distribution. The $\alpha$-divergence, defined as:

$$D_{\alpha}(p \| q) = -\frac{\int p(\mathbf{x})^{\alpha} q(\mathbf{x})^{1-\alpha} \, d\mathbf{x}}{\alpha(1-\alpha)},$$

is a measure of dissimilarity between two probability distributions, $p$ and $q$. The behavior of the $\alpha$-divergence changes as $\alpha$ varies: for $\alpha \leq 0$ it is mode-seeking, while for $\alpha \geq 1$ it is mass-covering (Figure 1).

Figure 1: Illustration of unnormalized Gaussian approximating distributions $q$, shown in red, that minimize the $\alpha$-divergence for different values of $\alpha$ with respect to a bimodal target distribution $p$, shown in blue.

The choice of $\alpha=2$ in FAB minimizes the variance of the importance sampling weights, which is beneficial for correcting bias in samples from the flow.
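To make the connection between $D_{\alpha=2}$ and importance-weight variance concrete, the following sketch estimates the second moment of the weights $w = p/q$ on a toy 1D problem. The densities, seed, and sample size are illustrative assumptions, not from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative 1D setup (hypothetical, not the paper's targets): p is an
# unnormalized bimodal Gaussian mixture, q a broad Gaussian proposal.
def log_p(x):
    return np.logaddexp(-0.5 * (x - 2.0) ** 2, -0.5 * (x + 2.0) ** 2)

def log_q(x):
    return -0.5 * (x / 3.0) ** 2  # unnormalized N(0, 3^2)

x = rng.normal(0.0, 3.0, size=100_000)   # samples from q
log_w = log_p(x) - log_q(x)              # unnormalized log importance weights

# Up to normalization constants, D_{alpha=2}(p || q) grows with E_q[w^2],
# so minimizing it also controls the importance-weight variance.
second_moment = np.mean(np.exp(2 * log_w))
# Effective sample size: high weight variance -> low ESS.
ess = np.exp(log_w).sum() ** 2 / np.exp(2 * log_w).sum()
```

A flow trained to reduce $D_{\alpha=2}$ would, in this picture, push the effective sample size toward the nominal sample count.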

Annealed Importance Sampling (AIS)

AIS is used to generate samples in regions where the flow is a poor approximation of the target. AIS constructs a sequence of intermediate distributions, $p_i$, between the initial distribution $q$ and the target distribution $p$. Samples are drawn from $q$ and then iteratively transitioned through the intermediate distributions using MCMC. The final sample, $x_M$, is accompanied by an importance weight, $w_{\text{AIS}}$, which accounts for the transformations applied during the transitions.
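A minimal AIS sketch of this procedure, assuming geometric intermediates $p_i \propto q^{1-\beta_i} p^{\beta_i}$ and a single Metropolis step per distribution (the targets, step size, and number of intermediates are illustrative, not the paper's settings):

```python
import numpy as np

rng = np.random.default_rng(0)

def log_q(x):
    return -0.5 * x ** 2                 # initial distribution (std normal)

def log_p(x):                            # unnormalized bimodal target
    return np.logaddexp(-0.5 * (x - 3.0) ** 2, -0.5 * (x + 3.0) ** 2)

def log_gamma(x, beta):                  # geometric intermediate p_i
    return (1.0 - beta) * log_q(x) + beta * log_p(x)

def ais(n=5_000, n_dists=8, step=0.7):
    betas = np.linspace(0.0, 1.0, n_dists + 1)
    x = rng.normal(size=n)               # draw initial samples from q
    log_w = np.zeros(n)
    for b_prev, b_next in zip(betas[:-1], betas[1:]):
        # accumulate the AIS importance weight
        log_w += log_gamma(x, b_next) - log_gamma(x, b_prev)
        # one Metropolis step targeting the current intermediate
        prop = x + step * rng.normal(size=n)
        accept = np.log(rng.uniform(size=n)) < log_gamma(prop, b_next) - log_gamma(x, b_next)
        x = np.where(accept, prop, x)
    return x, log_w

x_ais, log_w_ais = ais()
```

The returned pairs $(x_M, w_{\text{AIS}})$ are exactly what FAB consumes during training.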

Flow AIS Bootstrap

FAB trains a flow $q$ to approximate a target $p$ by minimizing $D_{\alpha=2}(p \| q)$, which is estimated with AIS using $q$ as the initial distribution and $p^2/q$ as the target. The latter is the minimum-variance importance sampling distribution for estimating the $D_{\alpha=2}(p \| q)$ loss. This process can be seen as a form of bootstrapping: the flow $q$ is fit using samples generated by itself, after they have been improved with AIS to fit $p^2/q$. The gradient of the loss function is estimated using:

$\nabla_\theta \mathcal{L}(\theta) = -\operatorname{E}_{\text{AIS}}\left[ w_{\text{AIS}} \nabla_\theta \log q_\theta(\bar{x}_{\text{AIS}}) \right],$

where $\bar{x}_{\text{AIS}}$ and $w_{\text{AIS}}$ are the samples and weights generated by AIS when targeting $p^2/q$. Gradients are stopped during the AIS sampling to ensure stable training.
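The stop-gradient structure of this estimator can be sketched on a toy "flow" $q_\theta = \mathcal{N}(\mu, \sigma^2)$, where $\nabla_\theta \log q_\theta$ is available in closed form. All names, parameters, and the self-normalization of the weights are illustrative assumptions:

```python
import numpy as np

# Toy sketch of the FAB gradient estimate for q_theta = N(mu, sigma^2).
# The AIS samples and weights are held fixed (stop-gradient): only
# log q_theta is differentiated.
def grad_log_q(x, mu, sigma):
    # analytic gradient of log N(x; mu, sigma^2) wrt (mu, log sigma)
    z = (x - mu) / sigma
    return np.stack([z / sigma, z ** 2 - 1.0], axis=-1)

def fab_gradient(x_ais, log_w_ais, mu, sigma):
    w = np.exp(log_w_ais - log_w_ais.max())   # self-normalized weights
    w /= w.sum()
    # grad L(theta) = -E_AIS[ w_AIS * grad log q_theta(x_AIS) ]
    return -(w[:, None] * grad_log_q(x_ais, mu, sigma)).sum(axis=0)

rng = np.random.default_rng(0)
x_ais = rng.normal(2.0, 1.0, size=1_000)   # stand-ins for AIS samples
grad = fab_gradient(x_ais, np.zeros(1_000), mu=0.0, sigma=1.0)
# grad[0] is negative here: descending the loss pulls mu toward the samples
```

In an actual implementation the analytic gradient would be replaced by automatic differentiation through $\log q_\theta$, with the AIS outputs detached from the graph.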

Replay Buffer

To reduce the computational cost of AIS, FAB employs a prioritized replay buffer. The buffer stores AIS samples and their corresponding weights. During training, samples are drawn from the buffer with probability proportional to their AIS weights. A correction factor is applied to account for the difference between the current flow parameters and those used to generate the samples in the buffer. This replay buffer allows for the re-use of previously generated samples, reducing the number of target evaluations required.
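A minimal sketch of such a buffer, assuming FIFO eviction and sampling proportional to the stored AIS weights (the interface is hypothetical, and the correction factor for stale flow parameters mentioned above is omitted for brevity):

```python
import numpy as np

rng = np.random.default_rng(0)

class ReplayBuffer:
    """Stores AIS samples with log-weights; samples with priority ∝ weight."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.x, self.log_w = [], []

    def add(self, x, log_w):
        self.x.extend(x)
        self.log_w.extend(log_w)
        self.x = self.x[-self.capacity:]        # FIFO eviction of oldest
        self.log_w = self.log_w[-self.capacity:]

    def sample(self, n):
        log_w = np.asarray(self.log_w)
        p = np.exp(log_w - log_w.max())
        p /= p.sum()                            # priority ∝ AIS weight
        idx = rng.choice(len(self.x), size=n, p=p)
        return np.asarray(self.x)[idx], idx     # idx allows weight adjustment

buf = ReplayBuffer(capacity=10_000)
buf.add(rng.normal(size=100), rng.normal(size=100))
batch, idx = buf.sample(16)
```

Returning the indices alongside the batch is one way to support the weight correction when the flow parameters have drifted since the samples were generated.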

Experimental Results

FAB was evaluated on three challenging problems:

  • 2D Gaussian Mixture Model: FAB, with and without the replay buffer, accurately fit all modes of a complex 2D Gaussian mixture model. Other methods, such as those based on the KL divergence, often failed to capture all modes due to mode-seeking behavior (Figure 2).

    Figure 2: Contour lines for the target distribution $p$ and samples (blue discs) drawn from the approximation $q_\theta$ obtained by different methods on the mixture of Gaussians problem.

  • 32D "Many Well" Problem: FAB demonstrated strong performance on the 32D "Many Well" problem, a high-dimensional distribution with a large number of modes. FAB, particularly with the replay buffer, achieved results comparable to training with maximum likelihood, while other methods struggled (Figure 3).

    Figure 3: Samples from $q_\theta$ and target contours for marginal distributions over the first four elements of $x$ in the 32-dimensional Many Well problem.

  • Alanine Dipeptide Boltzmann Distribution: FAB was used to learn the Boltzmann distribution of alanine dipeptide, a challenging problem in molecular dynamics. FAB produced better results than training via maximum likelihood on MD samples, while using 100 times fewer target evaluations. Unbiased histograms of dihedral angles, nearly identical to the ground truth, were obtained after reweighting samples (Figure 4).

    Figure 4: From left to right, Ramachandran plots of the ground truth generated by MD, a flow model trained by ML on MD samples, and by FAB using a replay buffer before and after reweighting samples to eliminate bias.

The paper relates FAB to other methods that combine flows with MCMC, such as Stochastic Normalizing Flows (SNFs) and Continual Repeated Flow Annealed Transport (CRAFT). FAB differs from these methods in its use of the $\alpha$-divergence and the AIS bootstrap mechanism. The work also discusses connections to research on improving transition kernels and intermediate distributions in MCMC and AIS.

Implications and Future Directions

The results of this paper suggest that FAB is a promising approach for training normalizing flows to approximate complex target distributions. The method's ability to avoid mode-seeking behavior and reduce reliance on expensive MCMC samples makes it well-suited for problems in various fields, including molecular dynamics, computational physics, and Bayesian inference.

Future work could explore:

  • Scaling FAB to even higher-dimensional problems, such as modeling the Boltzmann distribution of larger proteins.
  • Combining FAB with more expressive flow architectures or with techniques for learning transition kernels in MCMC.
  • Investigating the use of control variates and defensive importance sampling to further reduce variance in the loss estimate.
  • Applying FAB to other types of models, such as diffusion models.

Conclusion

FAB offers a novel and effective approach for training normalizing flows. By combining $\alpha$-divergence minimization with an AIS bootstrapping mechanism, FAB can accurately approximate complex, multimodal target distributions without relying on samples from those distributions. The use of a prioritized replay buffer further enhances the efficiency and stability of the method. The experimental results demonstrate the potential of FAB for solving challenging problems in various scientific and engineering domains.
