Generalized Gumbel-Softmax (GenGS)
- GenGS is a framework for differentiable relaxations of discrete random variables, extending traditional Gumbel-Softmax to count and structured combinatorial domains.
- It employs techniques such as support truncation, linear transformations, and modular invertible mappings combined with convex regularization for low-variance gradient estimation.
- The approach enables efficient optimization in deep generative models, topic modeling, and structured prediction tasks, showing strong empirical improvements over conventional methods.
The Generalized Gumbel-Softmax (GenGS) framework encompasses a family of differentiable relaxation techniques for discrete random variables, providing a pathwise gradient estimator for both finite-support and count/infinite-support distributions, as well as structured combinatorial objects. GenGS extends the original Gumbel-Softmax reparameterization, which applies only to categorical or Bernoulli variables, to generic discrete laws and complex combinatorial structures. This is achieved through mechanisms such as support truncation, linear transformation, modular invertible mappings, and strongly-convex regularization, thus enabling low-variance, differentiable optimization in models with discrete or structured latent variables (Jang et al., 2016; Joo et al., 2020; Potapczynski et al., 2019; Paulus et al., 2020).
1. Foundations of Gumbel-Softmax and Its Limitations
The original Gumbel-Softmax estimator is based on the Gumbel-Max trick: samples from a categorical distribution with class probabilities $\pi_1, \dots, \pi_n$ are obtained via

$$z = \mathrm{one\_hot}\Big(\operatorname*{arg\,max}_{i} \big[\log \pi_i + g_i\big]\Big),$$

where $g_i \sim \mathrm{Gumbel}(0, 1)$ i.i.d. As the discrete $\arg\max$ is non-differentiable, it is replaced by a softmax relaxation at temperature $\tau > 0$:

$$y_i = \frac{\exp\big((\log \pi_i + g_i)/\tau\big)}{\sum_{j=1}^{n} \exp\big((\log \pi_j + g_j)/\tau\big)}, \qquad i = 1, \dots, n.$$

As $\tau \to 0$, $y$ converges to a one-hot vector. This relaxation enables backpropagation and thus efficient stochastic optimization for models involving discrete random variables. However, the classical Gumbel-Softmax is limited to finite, fixed-support settings (categorical or Bernoulli). It cannot readily handle unbounded count distributions or combinatorial objects such as trees, matchings, or permutations (Jang et al., 2016).
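The two steps above (Gumbel perturbation of the log-probabilities, then a tempered softmax) can be sketched in a few lines of NumPy; the helper name, seed, and temperatures below are illustrative:

```python
import numpy as np

def sample_gumbel_softmax(logits, tau, rng):
    """Draw one relaxed categorical sample.

    logits: unnormalized log-probabilities log(pi_i), shape [n].
    tau: temperature; as tau -> 0 the sample approaches one-hot.
    """
    # Gumbel(0, 1) noise via the inverse CDF: g = -log(-log(U)), U ~ Uniform(0, 1)
    u = rng.uniform(size=logits.shape)
    g = -np.log(-np.log(u))
    # Softmax of the perturbed logits at temperature tau
    y = np.exp((logits + g) / tau)
    return y / y.sum()

rng = np.random.default_rng(0)
logits = np.log(np.array([0.2, 0.5, 0.3]))
y_soft = sample_gumbel_softmax(logits, tau=0.5, rng=rng)   # dense point on the simplex
y_hard = sample_gumbel_softmax(logits, tau=0.01, rng=rng)  # nearly one-hot
```

Both outputs lie on the probability simplex; only the temperature controls how close they sit to a vertex.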
2. Generalization to Arbitrary Discrete Laws: The Truncation and Linearization Approach
GenGS generalizes the Gumbel-Softmax estimator to any discrete random variable, especially those with potentially unbounded (e.g., count) support (Joo et al., 2020). The method proceeds as follows:
- Truncation: For a discrete variable $X \sim p$ (e.g., Poisson, binomial), select a truncation level $n$, defining a new variable $X_n$ restricted to the first $n+1$ outcomes of the support. The support is thus $\{c_0, \dots, c_n\}$, with pmf $\pi_k = p(c_k) / \sum_{j=0}^{n} p(c_j)$.
- Gumbel-Softmax Relaxation: Apply standard Gumbel-Softmax to the finite pmf $(\pi_0, \dots, \pi_n)$:

$$y_k = \frac{\exp\big((\log \pi_k + g_k)/\tau\big)}{\sum_{j=0}^{n} \exp\big((\log \pi_j + g_j)/\tau\big)},$$

producing a relaxed one-hot vector $y \in \Delta^{n}$.
- Linear Detachment: Transform the relaxed representation back to the original support by

$$x = \sum_{k=0}^{n} c_k \, y_k,$$

where $c_0, \dots, c_n$ are the (possibly integer) outcomes of $X_n$.
By controlling $n$ (with $n \to \infty$ recovering the full law), GenGS applies to any discrete law and specializes to the original Gumbel-Softmax when $X$ is categorical. The estimator supports both explicit and implicit parameterizations of the pmf and allows gradients to flow via the full chain $\pi \to y \to x$ (Joo et al., 2020).
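The truncate-relax-detach recipe can be illustrated for a Poisson law. The function below is a hedged sketch of the procedure, not code from the cited papers; the truncation level, temperature, and seed are illustrative choices:

```python
import math
import numpy as np

def gengs_poisson_sample(rate, n_trunc, tau, rng):
    """Relaxed GenGS-style sample from a truncated Poisson(rate)."""
    outcomes = np.arange(n_trunc + 1)
    # Log-pmf of Poisson on {0, ..., n_trunc}: k*log(rate) - rate - log(k!)
    log_pmf = outcomes * math.log(rate) - rate - np.array(
        [math.lgamma(k + 1) for k in outcomes])
    log_pmf = log_pmf - np.log(np.exp(log_pmf).sum())  # renormalize after truncation
    # Gumbel-Softmax over the truncated finite support
    g = -np.log(-np.log(rng.uniform(size=log_pmf.shape)))
    y = np.exp((log_pmf + g) / tau)
    y = y / y.sum()
    # Linear detachment: map the relaxed one-hot back to the integer outcomes
    return float(outcomes @ y)

rng = np.random.default_rng(1)
samples = [gengs_poisson_sample(rate=4.0, n_trunc=20, tau=0.1, rng=rng)
           for _ in range(2000)]
```

At low temperature and with negligible truncated mass, the empirical mean of the relaxed samples should track the Poisson mean (here 4.0).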
3. Structured and Combinatorial GenGS: Convex Relaxation and Perturbation Models
A further generalization addresses discrete distributions over structured objects—such as -subsets, permutations, matchings, and spanning trees—where the sample space is a combinatorial set (Paulus et al., 2020). The generalized perturbation model framework defines:
$$X_\tau(U) = \operatorname*{arg\,max}_{\mu \in \mathcal{P}} \; \langle U, \mu \rangle - \tau \, \Omega(\mu),$$

where $U$ is a random utility vector (typically Gumbel-distributed), $\mathcal{P} = \mathrm{conv}(\mathcal{X})$ is the convex hull of the combinatorial set, and $\Omega$ is a strongly-convex regularizer (e.g., negative entropy, squared Euclidean norm). As $\tau \to 0$, $X_\tau(U)$ approaches the original combinatorial sample $\operatorname*{arg\,max}_{x \in \mathcal{X}} \langle U, x \rangle$; at finite $\tau$ it is a continuous relaxation in $\mathcal{P}$. The relaxation supports efficient gradient estimation via autodiff through $X_\tau(U)$, enabling reparameterization for arbitrary perturbation models.
Table: Structured Domains for GenGS Relaxations (Paulus et al., 2020)
| Domain | Convex Polytope | Regularizer / Solver |
|---|---|---|
| One-hot (categorical) | Simplex $\Delta^{n-1}$ | Negative entropy (softmax) |
| $k$-subset | $k$-subset polytope $\{x \in [0,1]^n : \sum_i x_i = k\}$ | Negative entropy, Fenchel-Young relaxations |
| Permutations (matchings) | Birkhoff polytope | Entropy (Sinkhorn iterations) |
| Spanning trees | Spanning tree polytope | Entropy (matrix-tree, minor of the Laplacian) |
This framework enables GenGS to be applied to deep generative models, structured prediction, and neural relational inference over combinatorial structures, with problem-specific solvers (e.g., dynamic programming, Sinkhorn iterations, matrix-tree computations) providing the necessary convex optimizations.
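As one concrete instance, a relaxed sample over permutation matrices can be approximated by perturbing a score matrix with Gumbel noise and projecting toward the Birkhoff polytope with Sinkhorn iterations (entropy regularization). This is an illustrative sketch; the iteration count and score matrix are arbitrary choices:

```python
import numpy as np

def relaxed_permutation(scores, tau, n_iters=50, rng=None):
    """Relaxed sample over permutation matrices via Gumbel noise + Sinkhorn.

    Approximately solves argmax_{mu in Birkhoff} <U, mu> - tau * negentropy(mu)
    by alternately normalizing rows and columns of exp(U / tau) in log space.
    """
    if rng is None:
        rng = np.random.default_rng()
    g = -np.log(-np.log(rng.uniform(size=scores.shape)))  # Gumbel(0, 1) noise
    log_m = (scores + g) / tau
    for _ in range(n_iters):
        log_m -= np.log(np.exp(log_m).sum(axis=1, keepdims=True))  # rows
        log_m -= np.log(np.exp(log_m).sum(axis=0, keepdims=True))  # columns
    return np.exp(log_m)  # doubly stochastic; near a permutation as tau -> 0

rng = np.random.default_rng(2)
P = relaxed_permutation(np.eye(4) * 3.0, tau=0.2, rng=rng)
```

The output is a doubly stochastic matrix: a point inside the Birkhoff polytope rather than a hard permutation.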
4. Modular Invertible Reparameterizations and Infinite-Support Extensions
The "Invertible Gaussian Reparameterization" (IGR) approach generalizes GenGS by introducing invertible mappings from multivariate Gaussian noise to the simplex or other target polytopes (Potapczynski et al., 2019). The process is:
- Sample $\epsilon \sim \mathcal{N}(0, I)$.
- Transform $y = \mu + \sigma \odot \epsilon$.
- Apply a modular invertible function $g$, such as a softmax-based map or stick-breaking plus softmax, to obtain $z = g(y)$ in the simplex.
Stick-breaking in particular supports extension to countable or infinite categories by adaptive truncation. The mapping is fully invertible, allowing the density and its Jacobian determinant to be calculated in closed form. In practical inference, a continuous "IGR prior" can replace the categorical prior, yielding tractable KL divergences and improved pathwise gradients. This approach enables closed-form KLs, lower-variance gradients, and direct integration with nonparametric Bayesian architectures.
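A minimal sketch of the softmax-based variant of this pipeline follows. Fixing one extra "anchor" logit at zero is one standard way to make the map from $\mathbb{R}^{n-1}$ to the interior of the simplex invertible; all names here are illustrative:

```python
import numpy as np

def igr_sample(mu, sigma, tau, rng):
    """Invertible-Gaussian-reparameterization sketch: Gaussian noise -> simplex."""
    eps = rng.standard_normal(mu.shape)        # epsilon ~ N(0, I)
    y = mu + sigma * eps                       # location-scale transform, y in R^{n-1}
    logits = np.concatenate([y / tau, [0.0]])  # fixed anchor logit keeps g invertible
    z = np.exp(logits - logits.max())          # stabilized softmax
    return z / z.sum()                         # point in the interior of the simplex

rng = np.random.default_rng(3)
z = igr_sample(mu=np.zeros(4), sigma=np.ones(4), tau=1.0, rng=rng)
```

Since the anchor coordinate is fixed, $y_i / \tau = \log(z_i / z_n)$ recovers the Gaussian variable exactly, which is what makes the density and its Jacobian available in closed form.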
5. Empirical Performance and Applications
Extensive experiments demonstrate the utility of GenGS across a range of settings:
- Synthetic Problems: GenGS achieves lower loss, gradient variance, and bias in parameter learning for Poisson, binomial, multinomial, and negative binomial objectives compared to REINFORCE, NVIL, MuProp, RELAX, and control variates (Joo et al., 2020).
- Discrete-Latent VAEs: On MNIST and Omniglot, GenGS with Poisson, geometric, and negative binomial priors provides lower negative ELBOs than competing estimators by 10–30 nats; annealing the temperature enhances performance (Joo et al., 2020).
- Topic Modeling: Deep Poisson topic models using GenGS for latent counts (20Newsgroups, RCV1) achieve the best perplexities, outperforming VIMCO, REBAR, and related baselines (Joo et al., 2020).
- Structured Prediction: On neural relational inference, unsupervised parsing, and set/explainer tasks, structured GenGS variants recover more latent structure and achieve superior ELBO or accuracy compared to independent, unstructured, or less tailored relaxations (Paulus et al., 2020).
- Density Modeling and Nonparametrics: Modular, invertible GenGS variants achieve closed-form Kullback-Leibler divergences and have been shown to outperform standard Gumbel-Softmax in both flexibility and empirical performance (Potapczynski et al., 2019).
6. Practical Considerations and Hyperparameter Tuning
Key hyperparameters in GenGS include the truncation level $n$ (for count laws or infinite support), the temperature $\tau$, and the choice of convex regularizer $\Omega$. Guidelines include:
- Temperature: Select $\tau$ via cross-validation or annealing; very low $\tau$ may induce high gradient variance.
- Truncation Level: Increase $n$ until the probability mass outside the truncated support is negligible.
- Regularizer: Negative entropy yields dense, closed-form softmax-style solutions; a quadratic regularizer yields sparse solutions (as in sparsemax) at the cost of a projection step. For structured domains, regularizer choice trades off computational tractability and tightness of the relaxation.
- Noise Distribution: Gumbel is canonical for softmax relaxations, but Exponential, Gaussian, and Logistic perturbations are suitable for alternative combinatorial domains or Bernoulli sampling (Paulus et al., 2020).
- Algorithmic Integration: GenGS is implemented as a deterministic computation node; gradients are propagated automatically through the entire sampling and transformation pipeline.
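For the temperature guideline above, a common choice is an exponential annealing schedule for $\tau$; the constants in this sketch are illustrative defaults, not prescribed values:

```python
import math

def annealed_temperature(step, tau_init=1.0, tau_min=0.1, rate=1e-4):
    """Exponential temperature annealing with a floor.

    tau(step) = max(tau_min, tau_init * exp(-rate * step)).
    Starting warm keeps gradient variance low early in training; cooling
    tightens the relaxation toward discrete samples.
    """
    return max(tau_min, tau_init * math.exp(-rate * step))
```

In practice the schedule is evaluated once per training step (or per epoch) and the result passed to the relaxation as the current temperature.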
7. Theoretical Insights and Future Directions
GenGS synthesizes pathwise and score-function gradient estimation, enabling low-variance, end-to-end differentiable modeling of arbitrary discrete variables and combinatorial objects. The use of modular invertible transformations, adaptive or hierarchical temperature schemes, and structured relaxations broadens the class of tractable variational models. Open research directions involve tighter control of bias at finite $\tau$, integration with deep normalizing flows, and the extension to more general infinite-support discrete or combinatorial distributions while maintaining computational efficiency (Jang et al., 2016; Paulus et al., 2020; Potapczynski et al., 2019; Joo et al., 2020).