Rejection Sampling with Autodifferentiation (RSA)
- RSA is an algorithmic framework that integrates classical rejection sampling with autodifferentiation, yielding tractable and low-variance gradient estimates.
- It employs reparameterizable accept/reject mechanisms and smooth approximations to overcome non-differentiable operations, enhancing parameter learning in latent-variable models.
- RSA has demonstrated significant improvements in probabilistic inference and generative modeling, optimizing acceptance rates and reducing gradient noise in simulation-based tasks.
Rejection Sampling with Autodifferentiation (RSA) is an algorithmic and statistical framework designed to connect classical rejection sampling with modern gradient-based optimization and probabilistic modeling. By rendering the sampling mechanism differentiable, RSA achieves tractable, low-variance gradient estimates, facilitates efficient parameter learning in latent-variable models, and enables adaptive construction of expressive distributions for tasks in machine learning and simulation-based inference. RSA’s primary technical innovation is the introduction of reparameterizable accept–reject samplers or the use of sample reweighting, which allows direct integration with autodiff frameworks such as PyTorch, TensorFlow, or JAX.
1. Motivating Problems and Statistical Foundations
Rejection sampling provides a method for drawing independent samples from an intractable or unnormalized target density $p(x)$, using a tractable proposal $q(x)$ and an envelope constant $M$ such that $p(x) \le M q(x)$ for all $x$. One accepts a draw $x \sim q$ iff $u \le p(x) / (M q(x))$, where $u \sim \mathrm{Uniform}(0, 1)$, guaranteeing that the accepted points follow $p$ up to normalization.
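The accept/reject rule can be sketched in a few lines of standard-library Python. The target (an unnormalized Beta(2, 2) density), the uniform proposal, and the names `rejection_sample`, `p_tilde`, `q_sample`, and `q_pdf` are illustrative assumptions, not part of any cited method.

```python
import random


def rejection_sample(p_tilde, q_sample, q_pdf, M, rng):
    """Draw one sample from the density proportional to p_tilde.

    Returns (x, n_draws), where n_draws counts every proposal drawn,
    including the accepted one.
    """
    n_draws = 0
    while True:
        x = q_sample(rng)          # propose x ~ q
        n_draws += 1
        u = rng.random()           # u ~ Uniform(0, 1)
        # Accept iff u <= p_tilde(x) / (M * q(x)), written without division.
        if u * M * q_pdf(x) <= p_tilde(x):
            return x, n_draws


def p_tilde(x):
    """Illustrative target: unnormalized Beta(2, 2) density on [0, 1]."""
    return x * (1.0 - x)


def q_sample(rng):
    """Proposal draw from Uniform(0, 1)."""
    return rng.random()


def q_pdf(x):
    """Proposal density, equal to 1 on the support."""
    return 1.0


M = 0.25  # max of x * (1 - x), so p_tilde(x) <= M * q(x) everywhere
rng = random.Random(0)

results = [rejection_sample(p_tilde, q_sample, q_pdf, M, rng) for _ in range(20000)]
samples = [x for x, _ in results]
sample_mean = sum(samples) / len(samples)
acceptance_rate = len(results) / sum(n for _, n in results)

print(f"sample mean     ~ {sample_mean:.3f}  (Beta(2, 2) mean is 0.5)")
print(f"acceptance rate ~ {acceptance_rate:.3f}  (theory: (1/6) / 0.25 = 2/3)")
```

The `if` branch is the discrete, non-differentiable step: a small change in a parameter of `p_tilde` either flips an accept decision or leaves the output unchanged, so no useful gradient passes through it directly.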
In machine learning, stochastic variational inference (SVI) and generative models (VAEs, normalizing flows, etc.) employ tractable proposals to approximate intractable posteriors. Poorly matched proposals lead to high-variance Monte Carlo gradient estimates, particularly for score-function estimators. Classical rejection sampling can, in principle, yield samples closer to the true posterior and so reduce gradient variance, but it introduces a discrete, non-differentiable accept/reject branch that blocks gradient flow. RSA addresses this by reconstructing a continuous, differentiable mapping from the parameters of interest to the statistical properties of accepted samples, thereby enabling unbiased gradient flow through the full generative pipeline (Grover et al., 2018, Heller et al., 2024).
2. Core Algorithms and Mechanisms
RSA methods are characterized by their differentiation strategies:
A. Direct Reparametrization or Accept/Reject Networks
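- Variational Rejection Sampling (Grover et al., 2018): Given a variational proposal $q_\phi(z \mid x)$ and unnormalized joint $p_\theta(x, z)$, RSA constructs a new proposal using acceptance probability
$$a_{\theta,\phi}(z \mid x, T) = \min\left\{1,\; e^{T}\, \frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right\}$$
or its smooth approximation via softplus:
$$a_{\theta,\phi}(z \mid x, T) = e^{-[l_{\theta,\phi}(z \mid x, T)]^{+}}, \qquad l_{\theta,\phi}(z \mid x, T) = -\log p_\theta(x, z) + \log q_\phi(z \mid x) - T,$$
where $[l]^{+} = \log(1 + e^{l})$ ($T$ is a threshold that controls the acceptance rate). RSA samples $z \sim q_\phi(z \mid x)$, computes $a_{\theta,\phi}(z \mid x, T)$, samples $u \sim \mathrm{Uniform}(0, 1)$, and accepts/rejects accordingly, repeating until acceptance.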
- Learned Accept/Reject Sampling (LARS) (Bauer et al., 2018), as in resampled VAE priors and normalizing flow bases (Stimper et al., 2021): Use a neural network $a_\lambda(z) \in [0, 1]$ to parameterize acceptance probabilities for draws from a base density $\pi(z)$. The accepted sample density is
$$p(z) = \frac{\pi(z)\, a_\lambda(z)}{Z}$$
with $Z = \int \pi(z)\, a_\lambda(z)\, dz$. Gradients are estimated by reparameterizing and using Monte Carlo sampling: the MC estimate of $Z$ and its gradients are differentiable via autodiff.
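A minimal sketch of a softplus-smoothed accept/reject loop follows. It assumes the sign convention $l = \log q(z) - \log \tilde p(z) - T$ with acceptance probability $e^{-\mathrm{softplus}(l)}$, a hypothetical 1D example (proposal $N(0, 1)$, unnormalized target $N(1, 1)$), and illustrative function names; it is not the reference implementation of any cited paper.

```python
import math
import random


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def log_accept(z, log_p_tilde, log_q, T):
    """Log of the smoothed acceptance probability exp(-softplus(l)).

    Assumed convention: l = log q(z) - log p_tilde(z) - T, so a large T
    accepts almost everything and a small T favours high p_tilde / q.
    """
    l = log_q(z) - log_p_tilde(z) - T
    return -softplus(l)


def sample_accepted(n, log_p_tilde, log_q, q_sample, T, rng):
    """Run the accept/reject loop until n draws are accepted."""
    accepted, draws = [], 0
    while len(accepted) < n:
        z = q_sample(rng)
        draws += 1
        log_u = math.log(1.0 - rng.random())  # 1 - U lies in (0, 1], so the log is finite
        if log_u <= log_accept(z, log_p_tilde, log_q, T):
            accepted.append(z)
    return accepted, n / draws


def log_q(z):
    """Hypothetical proposal: standard normal log-density."""
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)


def log_p_tilde(z):
    """Hypothetical unnormalized target: N(1, 1) without its constant."""
    return -0.5 * (z - 1.0) ** 2


def q_sample(rng):
    return rng.gauss(0.0, 1.0)


rng = random.Random(0)
z_hi, rate_hi = sample_accepted(2000, log_p_tilde, log_q, q_sample, 20.0, rng)
z_lo, rate_lo = sample_accepted(2000, log_p_tilde, log_q, q_sample, -5.0, rng)
mean_hi = sum(z_hi) / len(z_hi)
mean_lo = sum(z_lo) / len(z_lo)

print(f"T = 20: acceptance {rate_hi:.3f}, mean {mean_hi:.2f} (close to the proposal mean 0)")
print(f"T = -5: acceptance {rate_lo:.3f}, mean {mean_lo:.2f} (close to the target mean 1)")
```

Under these assumptions the threshold trades compute for accuracy: a high threshold returns the proposal nearly unchanged at almost no cost, while a low threshold moves the accepted samples toward the target at the price of many more proposal draws.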
B. Event Reweighting and Replay
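- Event reweighting (Heller et al., 2024): Running a rejection sampler at fixed parameters $\theta$, the full history of accepted and rejected draws is retained. For each accepted $x$ with rejects $x_1, \dots, x_n$, the weight for a new parameter setting $\theta'$ is
$$w(\theta') = \frac{a(x; \theta')}{a(x; \theta)} \prod_{i=1}^{n} \frac{1 - a(x_i; \theta')}{1 - a(x_i; \theta)}$$
with $a(\cdot\,; \theta)$ the acceptance probability at parameters $\theta$. Expectations under a new parameter set are computed by weighted averages over the base sample population, and gradients propagate via the smooth dependence of $w$ on $\theta'$.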
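The reweighting idea can be illustrated with a toy sampler whose accept/reject history is stored per event. The acceptance function `accept_prob` (here $x^{\theta}$ on $[0, 1)$ with over-estimate 1), the parameter values, and all function names are assumptions chosen so that the exact answer is known; the weight is the ratio of the probability of the whole history under the new and the base parameters.

```python
import random


def accept_prob(x, theta):
    """Hypothetical acceptance probability f(x; theta) / f_hat, with
    f(x; theta) = x**theta on [0, 1) and over-estimate f_hat = 1 (theta >= 0)."""
    return x ** theta


def generate_event(theta, rng):
    """One run of the rejection sampler: returns the accepted draw and all rejected draws."""
    rejected = []
    while True:
        x = rng.random()                      # proposal: Uniform[0, 1)
        if rng.random() < accept_prob(x, theta):
            return x, rejected
        rejected.append(x)


def event_weight(x, rejected, theta_base, theta_new):
    """Probability of the full accept/reject history under theta_new over that under theta_base."""
    w = accept_prob(x, theta_new) / accept_prob(x, theta_base)
    for r in rejected:
        w *= (1.0 - accept_prob(r, theta_new)) / (1.0 - accept_prob(r, theta_base))
    return w


rng = random.Random(0)
theta_base, theta_new = 1.0, 1.5

# Generate once at the base parameters, keeping every rejected draw.
events = [generate_event(theta_base, rng) for _ in range(20000)]
weights = [event_weight(x, rej, theta_base, theta_new) for x, rej in events]

mean_weight = sum(weights) / len(weights)
reweighted_mean = sum(w * x for w, (x, _) in zip(weights, events)) / sum(weights)
exact_mean = (theta_new + 1.0) / (theta_new + 2.0)  # mean of the density (theta + 1) * x**theta

print(f"mean weight     ~ {mean_weight:.3f}  (should be close to 1)")
print(f"reweighted mean ~ {reweighted_mean:.3f}  (exact: {exact_mean:.3f})")
```

No new events are generated for `theta_new`: the stored sample is reused, and because each weight is a smooth function of `theta_new`, an autodiff framework could differentiate the weighted average with respect to it.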
C. Gradient-Refined Proposals and Empirical Bound Tightening
- Gradient-refined rejection sampling (Raff et al., 2023): For a target $p(x)$ and proposal $q(x)$, the envelope constant $M$ is adaptively minimized via gradient descent on surrogate objectives (e.g., hinge loss or empirical softmax ratios), efficiently shrinking $M$ and boosting acceptance rates. Refinement uses autodiff through the proposal density $q(x)$.
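As a loose illustration of lowering an envelope constant with gradients through the proposal density, the sketch below runs gradient descent on a log-sum-exp surrogate for $\max_x \log(\tilde p(x) / q_\mu(x))$ over a fixed grid. This is not the algorithm of Raff et al.; the Gaussian target and proposal, the grid, the step size, and the hand-written derivative (which an autodiff framework would normally supply) are all assumptions.

```python
import math


def log_ratio(x, mu, sigma=2.0):
    """log p_tilde(x) - log q_mu(x) for p_tilde = exp(-(x - 2)^2 / 2), q_mu = N(mu, sigma^2)."""
    log_p = -0.5 * (x - 2.0) ** 2
    log_q = -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))
    return log_p - log_q


def surrogate_and_grad(mu, xs, sigma=2.0):
    """Log-sum-exp surrogate for max_x log(p_tilde / q_mu) and its derivative in mu."""
    r = [log_ratio(x, mu, sigma) for x in xs]
    r_max = max(r)
    e = [math.exp(ri - r_max) for ri in r]   # shifted for numerical stability
    total = sum(e)
    value = r_max + math.log(total)
    # d log_ratio / d mu = -(x - mu) / sigma^2, averaged with softmax weights.
    grad = sum(ei * (-(x - mu) / sigma ** 2) for ei, x in zip(e, xs)) / total
    return value, grad


def empirical_log_M(mu, xs):
    """Empirical log envelope constant: the largest log ratio on the grid."""
    return max(log_ratio(x, mu) for x in xs)


xs = [-6.0 + 0.08 * i for i in range(201)]   # fixed evaluation grid on [-6, 10]
mu = -1.0                                    # deliberately poor starting proposal mean
M_before = math.exp(empirical_log_M(mu, xs))

step_size = 0.5
for _ in range(200):
    _, g = surrogate_and_grad(mu, xs)
    mu -= step_size * g

M_after = math.exp(empirical_log_M(mu, xs))
Z_target = math.sqrt(2.0 * math.pi)          # normalizer of p_tilde, known in this toy case

print(f"refined mu = {mu:.4f}")
print(f"M: {M_before:.2f} -> {M_after:.2f}")
print(f"acceptance rate: {Z_target / M_before:.3f} -> {Z_target / M_after:.3f}")
```

In this toy setting the proposal mean moves onto the target mode and the acceptance rate rises to the best value achievable with the fixed proposal width.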
3. Gradient Estimation and Autodifferentiation
RSA enables direct computation of gradients via two principal approaches:
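- Covariance-Type Estimators (Grover et al., 2018): because the accepted density $r \propto q_\phi\, a_\phi$ (proposal times acceptance probability) is known only up to normalization, the gradient of an expectation under it takes a covariance form:
$$\nabla_\phi\, \mathbb{E}_{r}[f(z)] = \mathbb{E}_{r}[\nabla_\phi f(z)] + \mathrm{Cov}_{r}\big(f(z),\, \nabla_\phi \log\,[q_\phi(z)\, a_\phi(z)]\big).$$
- Backpropagation Through MC Estimates (Bauer et al., 2018): with base density $\pi$ and learned acceptance function $a_\lambda$,
$$\log p(z) = \log \pi(z) + \log a_\lambda(z) - \log Z, \qquad Z \approx \frac{1}{S} \sum_{s=1}^{S} a_\lambda(z_s), \quad z_s \sim \pi.$$
The MC estimation enables differentiated log-density and ELBO computations within VAE or flow architectures.
- Event-Weight Gradients (Heller et al., 2024): with event weights $w_j(\theta')$ and per-event observable values $O_j$,
$$\nabla_{\theta'} \langle O \rangle_{\theta'} \approx \nabla_{\theta'}\, \frac{\sum_j w_j(\theta')\, O_j}{\sum_j w_j(\theta')}.$$
Chain rule application and autodiff frameworks facilitate the unbiased estimation of sensitivity of observables to model parameters.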
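To make the second approach concrete, the sketch below differentiates a resampled log-density through its Monte Carlo normalizer by hand and checks the result against a central finite difference. The one-parameter acceptance function `accept`, the standard-normal base density, and the evaluation point are hypothetical; in practice an autodiff framework would produce `grad_log_density` automatically.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def accept(z, lam):
    """Hypothetical one-parameter acceptance function a_lam(z) = sigmoid(lam * z - 1)."""
    return sigmoid(lam * z - 1.0)


def log_density(z, lam, base_samples):
    """log p(z) = log pi(z) + log a_lam(z) - log Z_hat, with pi = N(0, 1)."""
    log_pi = -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)
    z_hat = sum(accept(s, lam) for s in base_samples) / len(base_samples)
    return log_pi + math.log(accept(z, lam)) - math.log(z_hat)


def grad_log_density(z, lam, base_samples):
    """d/d lam of log_density, chaining through both log a_lam(z) and log Z_hat."""
    a_z = accept(z, lam)
    d_log_a = (1.0 - a_z) * z                       # derivative of log sigmoid(lam * z - 1)
    a_s = [accept(s, lam) for s in base_samples]
    z_hat = sum(a_s) / len(a_s)
    d_z_hat = sum(a * (1.0 - a) * s for a, s in zip(a_s, base_samples)) / len(a_s)
    return d_log_a - d_z_hat / z_hat


rng = random.Random(0)
base_samples = [rng.gauss(0.0, 1.0) for _ in range(1000)]  # fixed draws z_s ~ pi
z, lam, eps = 0.7, 1.3, 1e-5

analytic = grad_log_density(z, lam, base_samples)
numeric = (log_density(z, lam + eps, base_samples)
           - log_density(z, lam - eps, base_samples)) / (2.0 * eps)

print(f"analytic gradient:   {analytic:.6f}")
print(f"finite difference:   {numeric:.6f}")
```

Holding `base_samples` fixed while differentiating is what makes the estimate of $Z$ an ordinary differentiable function of the parameter; the accept/reject branch itself never appears in the computation.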
4. Practical Considerations and Performance Characteristics
Implementation of RSA requires attention to computational trade-offs:
- Envelope Constant Selection: Acceptance rates depend sensitively on $M$ (or the threshold $T$); $M$ must uniformly bound $p(x)/q(x)$ over the support. Adaptive estimation via empirical supremum tracking or gradient refinement is standard (Raff et al., 2023).
- Computational Cost: The expected number of proposal draws per accepted sample scales as $M$ (for normalized $p$ and $q$) or $1/Z$, where $Z$ is the average acceptance probability. In high-rejection scenarios, batching and vectorization are recommended. Truncation and forced acceptance limit computational overhead.
- Variance Reduction: Covariance estimators, batch-based averaging, moving averages on $Z$, and control variates (such as subtracting sample means) reduce gradient noise, as do optimal proposal selection and regular updating of auxiliary parameters.
- Integration: Encapsulation in autodiff frameworks requires careful management of the accept/reject loop, consistent recording of acceptance histories, and numerically stable applications of softplus/logarithm/sigmoid approximations.
- Memory: Event history storage for reweighting may demand significant memory; strategies such as zero-padding, random seed regeneration, and in-place computation mitigate resource requirements (Heller et al., 2024).
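One way to read "truncation and forced acceptance" is a cap on the number of proposals, with the last one accepted unconditionally. The sketch below simulates that scheme and compares the observed cost with the closed-form values implied by a geometric number of trials; the acceptance function, the uniform base distribution, and the cap of 5 are assumptions.

```python
import random


def truncated_rejection_sample(accept_prob, base_sample, max_draws, rng):
    """Accept/reject with a cap: if the first max_draws - 1 proposals are
    rejected, the final proposal is accepted unconditionally.

    Returns (z, n_draws, forced).
    """
    for n in range(1, max_draws):
        z = base_sample(rng)
        if rng.random() < accept_prob(z):
            return z, n, False
    return base_sample(rng), max_draws, True


def accept_prob(z):
    """Hypothetical acceptance function on [0, 1)."""
    return z * z


def base_sample(rng):
    """Base distribution: Uniform[0, 1)."""
    return rng.random()


rng = random.Random(0)
Z = 1.0 / 3.0        # average acceptance probability E[z^2] for z ~ Uniform(0, 1)
max_draws = 5

runs = [truncated_rejection_sample(accept_prob, base_sample, max_draws, rng)
        for _ in range(50000)]
mean_draws = sum(n for _, n, _ in runs) / len(runs)
forced_fraction = sum(1 for _, _, forced in runs if forced) / len(runs)

expected_draws = (1.0 - (1.0 - Z) ** max_draws) / Z    # untruncated cost would be 1 / Z = 3
expected_forced = (1.0 - Z) ** (max_draws - 1)

print(f"mean draws      ~ {mean_draws:.3f}  (theory: {expected_draws:.3f})")
print(f"forced fraction ~ {forced_fraction:.3f}  (theory: {expected_forced:.3f})")
```

The cap bounds the worst-case cost, but the output is then a mixture of the resampled density and the base density, with the forced fraction as the mixing weight, so any density evaluation has to account for that mixture.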
5. Applications and Benchmark Results
RSA has demonstrated utility across several domains:
- Variational Inference (MNIST SBNs, Poisson-like posteriors) (Grover et al., 2018):
- RSA achieved improvements of 3.71 nats (single-sample) and 0.21 nats (multi-sample) over baselines for marginal log-likelihood estimation.
- Lower acceptance quantiles yielded tighter ELBOs at the cost of more proposal samples, illustrating a compute/accuracy trade-off.
- Low-Dimensional Distribution Sampling (Raff et al., 2023):
- RSA improved acceptance rates over KDE-based NNARS in 1D “peakiness” benchmarks, was also tested on 1D multimodal “clutter,” and matched the best alternatives in higher dimensions.
- Empirical tests (KS, Cramér) detected no statistical deviations, confirming asymptotic exactness.
- Flexible Priors in VAEs and Flows (Bauer et al., 2018, Stimper et al., 2021):
- LARS priors reduced VAE NLL by $0.5$–$1.5$ nats; hierarchical architectures also improved.
- Learned rejection-based bases enabled normalizing flows to model distributions with complex topology while preserving bijectivity.
- Simulation-Based Inference and Model Fitting (Heller et al., 2024):
- RSA facilitated efficient parameter estimation using both binned and unbinned machine learning observables.
- Case study fitting Lund hadronization models employed RSA for smooth estimation of multiplicity, DeepSets classifier scores, and joint distribution fits, achieving confidence-ellipse shrinkage when incorporating ML-based observables.
6. Methodological Extensions, Limitations, and Recommendations
Key extensions include:
- Adaptive Proposal Refinement: Simultaneous learning and tightening of proposal envelopes via gradient methods.
- Reweighting Across Parameterizations: Mixing samples from multiple base parameterizations reduces variance and expands effective coverage.
- Handling Discrete Branches and Nontrivial Generative Histories: RSA extends naturally to branching Monte Carlo algorithms by capturing full decision histories for likelihood reweighting (Heller et al., 2024).
Limitations:
- Envelope Constant Estimation: For arbitrary or highly multimodal targets, tight bounding of $M$ or of the ratio $p(x)/q(x)$ may require extensive pilot sampling or specialized surrogates.
- Black-Box Generators: RSA assumes access to the explicit densities $p$ and $q$; for simulators lacking closed-form densities, classifier-based or surrogate reweighting may be needed.
- Memory Scaling: Extremely large final states (e.g., hundreds of particles in physics simulation) challenge the storage and computation of histories and weights; in-place or seed-based solutions are used in practice.
Recommended practices:
- Choose proposals $q$ close to the target $p$ to minimize rejections and the variance of importance weights.
- Update envelope constants or thresholds dynamically, but hold them fixed during any single gradient computation.
- Monitor both acceptance rates and gradient variance, adjusting batch sizes, quantiles, or refinement steps as needed.
- For unbinned and ML-based observables, compute loss gradients via importance-weighted sums to exploit the continuous nature of RSA weights.
7. Connections to Broader Research and Future Directions
RSA integrates classical Monte Carlo principles with current trends in differentiable programming and simulation-based inference. It subsumes prior adaptive rejection methods by leveraging empirical and gradient information, and links to score-function and reparameterization-based gradient estimators in latent-variable models. RSA’s flexibility in handling expressive priors, complex simulation pipelines, and ML-based summarization suggests wide applicability in likelihood-free inference, generative modeling (VAEs, normalizing flows), and scientific simulation.
A plausible implication is that RSA or similar event-reweighting schemes will become standard in domains where simulation and autodiff are coupled for parameter estimation, especially in high energy physics, population dynamics, and other fields relying on generative stochastic models.
References:
- (Grover et al., 2018) Variational Rejection Sampling
- (Raff et al., 2023) An Easy Rejection Sampling Baseline via Gradient Refined Proposals
- (Bauer et al., 2018) Resampled Priors for Variational Autoencoders
- (Stimper et al., 2021) Resampling Base Distributions of Normalizing Flows
- (Heller et al., 2024) Rejection Sampling with Autodifferentiation - Case study: Fitting a Hadronization Model