
Rejection Sampling with Autodifferentiation (RSA)

Updated 30 January 2026
  • RSA is an algorithmic framework that integrates classical rejection sampling with autodifferentiation, yielding tractable and low-variance gradient estimates.
  • It employs reparameterizable accept/reject mechanisms and smooth approximations to overcome non-differentiable operations, enhancing parameter learning in latent-variable models.
  • RSA has demonstrated significant improvements in probabilistic inference and generative modeling, optimizing acceptance rates and reducing gradient noise in simulation-based tasks.

Rejection Sampling with Autodifferentiation (RSA) is an algorithmic and statistical framework designed to connect classical rejection sampling with modern gradient-based optimization and probabilistic modeling. By rendering the sampling mechanism differentiable, RSA achieves tractable, low-variance gradient estimates, facilitates efficient parameter learning in latent-variable models, and enables adaptive construction of expressive distributions for tasks in machine learning and simulation-based inference. RSA’s primary technical innovation is the introduction of reparameterizable accept–reject samplers or the use of sample reweighting, which allows direct integration with autodiff frameworks such as PyTorch, TensorFlow, or JAX.

1. Motivating Problems and Statistical Foundations

Rejection sampling provides a method for drawing independent samples from an intractable or unnormalized target density $f(x)$, using a tractable proposal $g(x)$ and an envelope constant $C \geq \sup_x f(x)/g(x)$. One draws $x \sim g$ and accepts iff $u < f(x)/(C g(x))$, where $u \sim \mathrm{Uniform}[0,1]$, guaranteeing that the accepted points follow $f(x)$ up to normalization.
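The accept/reject loop above can be sketched in a few lines of Python; the Beta(2,1) target, uniform proposal, and envelope constant $C = 2$ are illustrative choices, not taken from the cited papers:

```python
import random

def rejection_sample(f, sample_g, g, C, n, seed=0):
    """Draw n independent samples from the density proportional to f.

    f        : unnormalized target density
    sample_g : function rng -> one draw from the proposal
    g        : proposal density (must satisfy C >= sup_x f(x) / g(x))
    """
    rng = random.Random(seed)
    samples = []
    while len(samples) < n:
        x = sample_g(rng)
        u = rng.random()
        if u < f(x) / (C * g(x)):  # accept with probability f(x) / (C g(x))
            samples.append(x)
    return samples

# Toy example: target proportional to 2x on [0, 1] (a Beta(2,1) density),
# uniform proposal g(x) = 1, envelope constant C = 2.
xs = rejection_sample(lambda x: 2.0 * x, lambda r: r.random(), lambda x: 1.0, 2.0, 20000)
sample_mean = sum(xs) / len(xs)  # the true Beta(2,1) mean is 2/3
```

Note that a loose $C$ directly inflates the expected number of proposal draws per accepted sample, which is the cost that the gradient-refinement methods of Section 2C attack.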

In machine learning, stochastic variational inference (SVI) and generative models (VAEs, normalizing flows, etc.) employ tractable proposals to approximate intractable posteriors. Poorly matched proposals lead to high-variance Monte Carlo gradient estimates, particularly for score-function estimators. Classical rejection sampling can, in principle, yield samples closer to the true posterior, reducing gradient variance, but it introduces a discrete, non-differentiable accept/reject branch that blocks gradient flow. RSA addresses this by reconstructing a continuous, differentiable mapping from the parameters of interest to the statistical properties of accepted samples, thereby enabling unbiased gradient flow through the full generative pipeline (Grover et al., 2018; Heller et al., 2024).

2. Core Algorithms and Mechanisms

RSA methods are characterized by their differentiation strategies:

A. Direct Reparametrization or Accept/Reject Networks

  • Variational Rejection Sampling (Grover et al., 2018): Given a variational proposal $q_\phi(z|x)$ and unnormalized joint $\gamma_p(z) = p_\theta(x,z)$, RSA constructs a new proposal $R(z|x)$ using acceptance probability

a(z;x) = \min \left\{ 1, \frac{p_\theta(x,z)}{M q_\phi(z|x)} \right\}

or its smooth approximation via softplus:

\log a(z;x) = -\mathrm{softplus}(l(z;x))

where $l(z;x) = -\log p_\theta(x,z) + \log q_\phi(z|x) - T$ and $T = -\log M$. RSA samples $z \sim q_\phi$, computes $a$, samples $u \sim \mathrm{Uniform}[0,1]$, and accepts or rejects accordingly, repeating until acceptance. The accepted samples then follow the density

p_\phi(z) = \frac{q_\phi(z)\, a_\phi(z)}{Z_\phi}

with $Z_\phi = \int q_\phi(z)\, a_\phi(z)\, dz$. Gradients are estimated by reparameterizing $q_\phi(z)$ and using Monte Carlo sampling: the MC estimate of $Z_\phi$ and its gradients are differentiable via autodiff.
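A minimal sketch of the smoothed acceptance and the Monte Carlo estimate of $Z_\phi$, assuming a toy unnormalized Gaussian target $\gamma(z) = e^{-z^2/2}$ and proposal $q_\phi = \mathcal{N}(\mu, 1)$ (both illustrative choices; in a real model these come from $p_\theta$ and the variational family):

```python
import math
import random

def softplus(x):
    # numerically stable softplus: log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def log_accept(z, mu, T):
    # l(z) = -log gamma(z) + log q_phi(z) - T, with log a = -softplus(l)
    log_gamma = -0.5 * z * z                                    # toy unnormalized target
    log_q = -0.5 * (z - mu) ** 2 - 0.5 * math.log(2 * math.pi)  # N(mu, 1) log-density
    l = -log_gamma + log_q - T
    return -softplus(l)

def estimate_Z(mu, T, n=20000, seed=0):
    # Z_phi = E_{q_phi}[a_phi(z)], estimated with reparameterized draws z = mu + eps
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        z = mu + rng.gauss(0.0, 1.0)
        total += math.exp(log_accept(z, mu, T))
    return total / n
```

Because the smoothed acceptance equals $\sigma(-l)$, it always lies in $(0,1)$; lowering $T$ (i.e., raising $M$) tightens acceptance and shrinks $Z_\phi$, trading compute for a better-matched sampler.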

B. Event Reweighting and Replay

  • Event reweighting (Heller et al., 2024): Running a rejection sampler at fixed base parameters $\theta_0$, the full history of accepted and rejected draws is retained. For each accepted $x_a$ with rejected draws $\{x_r^j\}$, the weight for a new parameter setting $\theta$ is

w(x_a, \{x_r^j\}; \theta) = \frac{\alpha(x_a;\theta)}{\alpha(x_a;\theta_0)} \prod_{j=1}^{n-1} \frac{1-\alpha(x_r^j;\theta)}{1-\alpha(x_r^j;\theta_0)}

with $\alpha(x;\theta) = p(x;\theta)/(M q(x))$. Expectations under a new parameter set are computed as weighted averages over the base sample population, and gradients propagate via the smooth dependence of $w$ on $\theta$.
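The replay weight can be computed directly from a stored event history. Here $\alpha(x;\theta) = \sigma(\theta x)$ is a purely illustrative stand-in for $p(x;\theta)/(M q(x))$, not a form used in the cited work:

```python
import math

def alpha(x, theta):
    # illustrative acceptance probability, standing in for p(x; theta) / (M q(x))
    return 1.0 / (1.0 + math.exp(-theta * x))

def event_weight(x_accepted, x_rejected, theta, theta0):
    """Weight carrying one accepted draw (and its rejected draws) from theta0 to theta."""
    w = alpha(x_accepted, theta) / alpha(x_accepted, theta0)
    for xr in x_rejected:
        w *= (1.0 - alpha(xr, theta)) / (1.0 - alpha(xr, theta0))
    return w

# Sanity check: reweighting back to the generating parameters must give w = 1.
w_identity = event_weight(0.8, [0.3, -0.4], theta=1.2, theta0=1.2)
w_shifted = event_weight(0.8, [0.3, -0.4], theta=1.5, theta0=1.2)
```

The key property is that no resampling is needed: one stored history serves every $\theta$ in a neighborhood of $\theta_0$.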

C. Gradient-Refined Proposals and Empirical Bound Tightening

  • Gradient-refined rejection sampling (Raff et al., 2023): For a target $f(x)$ and proposal $g(x;\theta)$, the envelope constant $C(\theta)$ is adaptively minimized via gradient descent on surrogate objectives (e.g., hinge loss or empirical softmax ratios), efficiently shrinking $C$ and boosting acceptance rates. Refinement uses autodiff through the proposal density $g(x;\theta)$.
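A sketch of this idea under simplifying assumptions: a Laplace-like target, a zero-mean Gaussian proposal with learnable scale $\sigma$, a log-sum-exp softmax surrogate for the empirical supremum, and central finite differences standing in for autodiff (all illustrative choices, not the exact objectives of Raff et al.):

```python
import math

def log_ratio_surrogate(sigma, xs, beta=20.0):
    # smooth surrogate for log C(sigma) = max_i [log f(x_i) - log g(x_i; sigma)]
    # over a pilot set xs, via a log-sum-exp softmax with inverse temperature beta
    log_f = lambda x: -abs(x)                       # Laplace-like unnormalized target
    log_g = lambda x: -0.5 * (x / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))
    vals = [beta * (log_f(x) - log_g(x)) for x in xs]
    m = max(vals)
    return (m + math.log(sum(math.exp(v - m) for v in vals))) / beta

def refine_scale(sigma, xs, lr=0.05, steps=200, h=1e-5):
    # gradient descent on the surrogate; central differences stand in for autodiff
    for _ in range(steps):
        grad = (log_ratio_surrogate(sigma + h, xs) - log_ratio_surrogate(sigma - h, xs)) / (2 * h)
        sigma -= lr * grad
    return sigma

pilot = [i / 10.0 for i in range(-30, 31)]   # pilot points covering the support
sigma_refined = refine_scale(0.5, pilot)
# the refined scale yields a smaller empirical envelope, hence a higher acceptance rate
```

In a real implementation the surrogate gradient would come from autodiff through $g(x;\theta)$ rather than finite differences, and the pilot set would be resampled as the proposal moves.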

3. Gradient Estimation and Autodifferentiation

RSA enables direct computation of gradients via two principal approaches. For variational rejection sampling, the R-ELBO gradients take covariance forms under the resampled distribution $r$:

\nabla_\phi \mathrm{R}\text{-}\mathrm{ELBO} = \mathrm{Cov}_r\big(A(z),\, (1-\sigma(l(z;x)))\, \nabla_\phi \log q_\phi(z|x)\big)

\nabla_\theta \mathrm{R}\text{-}\mathrm{ELBO} = \mathbb{E}_r[\nabla_\theta \log p_\theta(x,z)] - \mathrm{Cov}_r\big(A(z),\, \sigma(l(z;x))\, \nabla_\theta \log p_\theta(x,z)\big)

\nabla_\phi Z_\phi = \mathbb{E}_{q_\phi}[\nabla_\phi a_\phi(z)]

Monte Carlo estimation of these quantities enables differentiable log-density and ELBO computations within VAE or flow architectures. For event reweighting, the gradient of the replay weight follows from the chain rule:

\nabla_\theta w(x_a, \{x_r^j\}; \theta) = w \cdot \nabla_\theta \left( \log \alpha(x_a;\theta) + \sum_{j=1}^{n-1} \log(1-\alpha(x_r^j;\theta)) \right)

Applied through autodiff frameworks, the chain rule thus yields unbiased estimates of the sensitivity of observables to model parameters.
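The chain-rule expression for $\nabla_\theta w$ can be verified numerically. The acceptance function $\alpha(x;\theta) = \sigma(\theta x)$ below is a purely illustrative stand-in, chosen because its log-derivatives are closed-form: $\partial_\theta \log \alpha = x(1-\alpha)$ and $\partial_\theta \log(1-\alpha) = -x\alpha$:

```python
import math

def alpha(x, theta):
    # illustrative acceptance probability, standing in for p(x; theta) / (M q(x))
    return 1.0 / (1.0 + math.exp(-theta * x))

def weight(xa, xrs, theta, theta0):
    # replay weight for one accepted draw xa with rejected draws xrs
    w = alpha(xa, theta) / alpha(xa, theta0)
    for xr in xrs:
        w *= (1.0 - alpha(xr, theta)) / (1.0 - alpha(xr, theta0))
    return w

def weight_grad(xa, xrs, theta, theta0):
    # chain rule: dw/dtheta = w * d/dtheta [log alpha(xa) + sum_j log(1 - alpha(xr_j))]
    w = weight(xa, xrs, theta, theta0)
    g = xa * (1.0 - alpha(xa, theta))                # d/dtheta log sigmoid(theta x)
    g += sum(-xr * alpha(xr, theta) for xr in xrs)   # d/dtheta log(1 - sigmoid(theta x))
    return w * g

# Central finite-difference check of the analytic gradient.
xa, xrs, th, th0, h = 0.7, [0.2, -0.5], 1.3, 0.9, 1e-6
fd = (weight(xa, xrs, th + h, th0) - weight(xa, xrs, th - h, th0)) / (2 * h)
```

Summing such per-event gradients, weighted into any observable, gives the sensitivity estimates used for parameter fitting.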

4. Practical Considerations and Performance Characteristics

Implementation of RSA requires attention to computational trade-offs:

  • Envelope Constant Selection: Acceptance rates depend sensitively on $M$ (or $\hat{C}$), which must uniformly bound $f/q$ over the support. Adaptive estimation via empirical supremum tracking or gradient refinement is standard (Raff et al., 2023).
  • Computational Cost: The expected number of proposal draws per accepted sample scales as $1/\langle a(x) \rangle$ or $1/Z_\phi$. In high-rejection scenarios, batching and vectorization are recommended. Truncation and forced acceptance limit computational overhead.
  • Variance Reduction: Covariance estimators, batch-based averaging, moving averages on $Z_\phi$, and control variates (such as subtracting sample means) reduce gradient noise, as do optimal proposal selection and regular updating of auxiliary parameters.
  • Integration: Encapsulation in autodiff frameworks requires careful management of the accept/reject loop, consistent recording of acceptance histories, and numerically stable implementations of the softplus, logarithm, and sigmoid operations.
  • Memory: Event-history storage for reweighting may demand significant memory; strategies such as zero-padding, random-seed regeneration, and in-place computation mitigate resource requirements (Heller et al., 2024).

5. Applications and Benchmark Results

RSA has demonstrated utility across several domains:

  • Variational Inference (MNIST SBNs, Poisson-like posteriors) (Grover et al., 2018):
    • RSA achieved improvements of 3.71 nats (single-sample) and 0.21 nats (multi-sample) over baselines for marginal log-likelihood estimation.
    • Lower acceptance quantiles yielded tighter ELBOs with an increase in proposal samples, illustrating compute/accuracy trade-off.
  • Low-Dimensional Distribution Sampling (Raff et al., 2023):
    • RSA improved acceptance rates by up to $38\times$ over KDE-based NNARS in 1D “peakiness” benchmarks, reached $95\%$ in 1D multimodal “clutter,” and matched the best alternatives in higher dimensions (up to $d=7$).
    • Empirical tests (KS, Cramér) detected no statistical deviations, confirming asymptotic exactness.
  • Flexible Priors in VAEs and Flows (Bauer et al., 2018, Stimper et al., 2021):
    • LARS priors reduced VAE NLL by $0.5$–$1.5$ nats; hierarchical architectures saw $\sim 1$ nat improvement.
    • Learned rejection-based bases enabled normalizing flows to model distributions with complex topology while preserving bijectivity.
  • Simulation-Based Inference and Model Fitting (Heller et al., 2024):
    • RSA facilitated efficient parameter estimation using both binned and unbinned machine learning observables.
    • Case study fitting Lund hadronization models employed RSA for smooth estimation of multiplicity, DeepSets classifier scores, and joint distribution fits, achieving confidence-ellipse shrinkage when incorporating ML-based observables.

6. Methodological Extensions, Limitations, and Recommendations

Key extensions include:

  • Adaptive Proposal Refinement: Simultaneous learning and tightening of proposal envelopes via gradient methods.
  • Reweighting Across Parameterizations: Mixing samples from multiple base parameterizations reduces variance and expands effective coverage.
  • Handling Discrete Branches and Nontrivial Generative Histories: RSA extends naturally to branching Monte Carlo algorithms by capturing full decision histories for likelihood reweighting (Heller et al., 2024).

Limitations:

  • Envelope Constant Estimation: For arbitrary or highly multimodal targets, tight bounding of $M$ or $C$ may require extensive pilot sampling or specialized surrogates.
  • Black-Box Generators: RSA assumes access to the explicit densities $q(x;\phi)$ and $p(x;\theta)$; for simulators lacking closed-form densities, classifier-based or surrogate reweighting may be needed.
  • Memory Scaling: Extremely large final states (e.g., hundreds of particles in physics simulation) challenge the storage and computation of histories and weights; in-place or seed-based solutions are used in practice.

Recommended practices:

  • Choose proposals $q$ close to $p$ to minimize rejections and the variance of importance weights.
  • Update envelope constants or thresholds dynamically, but hold fixed during any single gradient computation.
  • Monitor both acceptance rates and gradient variance, adjusting batch sizes, quantiles, or refinement steps as needed.
  • For unbinned and ML-based observables, compute loss gradients via importance-weighted sums to exploit the continuous nature of RSA weights.

7. Connections to Broader Research and Future Directions

RSA integrates classical Monte Carlo principles with current trends in differentiable programming and simulation-based inference. It subsumes prior adaptive rejection methods by leveraging empirical and gradient information, and links to score-function and reparameterization-based gradient estimators in latent-variable models. RSA’s flexibility in handling expressive priors, complex simulation pipelines, and ML-based summarization suggests wide applicability in likelihood-free inference, generative modeling (VAEs, normalizing flows), and scientific simulation.

A plausible implication is that RSA or similar event-reweighting schemes will become standard in domains where simulation and autodiff are coupled for parameter estimation—especially in high energy physics, population dynamics, and other fields relying on generative stochastic models.

References:

(Grover et al., 2018) Variational Rejection Sampling
(Raff et al., 2023) An Easy Rejection Sampling Baseline via Gradient Refined Proposals
(Bauer et al., 2018) Resampled Priors for Variational Autoencoders
(Stimper et al., 2021) Resampling Base Distributions of Normalizing Flows
(Heller et al., 2024) Rejection Sampling with Autodifferentiation - Case study: Fitting a Hadronization Model
