Doubly Stochastic Adversarial Autoencoders
- The paper introduces a doubly stochastic mechanism that replaces a fixed discriminator with random feature sampling to smooth gradients and mitigate mode collapse.
- The methodology combines reconstruction loss with an adversarial penalty derived from stochastic random feature maps, bridging GAN and MMD approaches.
- Experimental evaluations on MNIST show that DS-AAE improves sample diversity and latent space exploration compared to traditional AAEs and MMD-AEs.
A Doubly Stochastic Adversarial Autoencoder (DS-AAE) is a probabilistic autoencoder architecture in which the conventional adversary of an Adversarial Autoencoder (AAE) is replaced by a stochastic function sampled from a space of random frequency features. This introduces a second source of algorithmic randomness, which smooths gradients, regularizes the adversarial training, and mitigates mode collapse. DS-AAE interpolates between Maximum Mean Discrepancy Autoencoders (MMD-AE) and classical AAEs/GANs depending on the choice of random feature distribution and parameterizations (Azarafrooz, 2018).
1. Motivation and Background
Variational Autoencoders (VAEs) and AAEs both enforce a prescribed prior $p(z)$ on the latent code. VAEs achieve this via a closed-form KL divergence penalty matching the latent posterior $q(z \mid x)$ to the prior $p(z)$, but this can yield sample blurriness and under-exploration of multimodal posteriors, since the KL term pushes each $q(z \mid x)$ towards $p(z)$ individually. AAEs replace the KL penalty with a GAN-style discriminator $D$, yielding objectives of the form

$$\min_{Q} \max_{D} \; \mathbb{E}_{z \sim p(z)}[\log D(z)] + \mathbb{E}_{x \sim p_{\text{data}}}[\log(1 - D(Q(x)))],$$

where $Q$ is the encoder.
AAEs can generate sharper samples, but adversarial training may suffer from mode collapse: the discriminator can quickly distinguish “fake” codes, pushing the encoder to collapse to a few latent modes to fool the discriminator.
DS-AAE replaces the deterministic discriminator with a space of stochastic functions, injecting additional randomness. This mechanism
- Smooths gradients for the encoder,
- Prevents overfitting of the adversary,
- Encourages the generator to explore more latent modes,
- Reduces mode collapse.
2. Model Components and Stochastic Adversary
The DS-AAE comprises an encoder $Q_\phi$, which deterministically maps input $x$ to latent code $z = Q_\phi(x)$; a decoder/generator $P_\theta$; and an imposed prior $p(z)$ (commonly $\mathcal{N}(0, I)$).
Conventional AAEs use a single neural-network discriminator. DS-AAE instead defines a function class

$$f(z) = \mathbb{E}_{\omega \sim \Lambda}\big[\alpha(\omega)\, \phi_\omega(z)\big],$$

where $\phi_\omega$ is a "doubly stochastic gradient feature" constructed by sampling a random frequency $\omega$ from a measure $\Lambda$ (often Gaussian for RBF kernels) and defining a random feature map

$$\phi_\omega(z) = \sqrt{2}\, \cos(\omega^\top z + b),$$

with the phase $b$, for example, drawn uniformly from $[0, 2\pi]$. Each function $f$ utilized during training depends on a fresh random draw $\omega \sim \Lambda$, ensuring the stochasticity of the adversary.
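The feature map above is the standard random Fourier feature construction of Rahimi and Recht. A minimal NumPy sketch (the bandwidth $\sigma = 1$ and batch sizes here are illustrative, not values from the paper) shows that scaled inner products of these features approximate the RBF kernel:

```python
import numpy as np

def random_feature_map(z, omegas, biases):
    """phi_omega(z) = sqrt(2) * cos(omega^T z + b), evaluated for a batch.

    z : (n, d) latent codes; omegas : (m, d) frequencies ~ N(0, I / sigma^2);
    biases : (m,) phases ~ Uniform[0, 2*pi].  Returns an (n, m) matrix whose
    scaled inner products approximate the RBF kernel with bandwidth sigma.
    """
    return np.sqrt(2.0) * np.cos(z @ omegas.T + biases)

rng = np.random.default_rng(0)
d, m, sigma = 6, 5000, 1.0
omegas = rng.normal(scale=1.0 / sigma, size=(m, d))
biases = rng.uniform(0.0, 2.0 * np.pi, size=m)

z1, z2 = rng.normal(size=(1, d)), rng.normal(size=(1, d))
phi1 = random_feature_map(z1, omegas, biases)
phi2 = random_feature_map(z2, omegas, biases)
approx = (phi1 @ phi2.T).item() / m                                  # Monte Carlo kernel estimate
exact = np.exp(-np.sum((z1 - z2) ** 2) / (2.0 * sigma ** 2)).item()  # exact RBF kernel
```

The Monte Carlo error decays as $O(1/\sqrt{m})$, which is why the number of feature samples matters for variance (see Section 4).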
3. Doubly Stochastic Minimax Objective
DS-AAE is formulated as a doubly stochastic saddle point problem

$$\min_{\phi, \theta} \; \max_{f} \; \mathcal{L}(\phi, \theta, f),$$

with objective

$$\mathcal{L}(\phi, \theta, f) = \mathcal{L}_{\text{rec}}(\phi, \theta) + \lambda\, D_f(p, q_\phi),$$

where the stochastic divergence is

$$D_f(p, q_\phi) = \mathbb{E}_{z \sim p(z)}[f(z)] - \mathbb{E}_{x \sim p_{\text{data}}}\big[f(Q_\phi(x))\big].$$

Maximizing over $f$ (equivalently, over the weights $\alpha(\omega)$ under a norm constraint) gives the witness function that maximally separates the prior from the aggregated posterior; in the random feature space, its optimal value is the norm of the gap between the two feature means.
For a large number of random features $m$, this converges to the MMD divergence; if $f$ is parameterized by a deep network and $\Lambda$ is degenerate (a point mass), the formulation recovers the GAN divergence.
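Because the divergence is linear in the feature weights $\alpha$, its maximum over the unit ball has a closed form: the maximizer points along the gap between the two feature means, and the maximal value is that gap's norm, i.e. a random-feature MMD estimate. A small NumPy check with stand-in feature matrices:

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 500, 200

# Stand-in feature matrices; in DS-AAE these would be phi_omega evaluated
# on prior draws z' ~ p(z) and on encoded codes Q_phi(x).
feat_prior = rng.normal(loc=0.5, size=(n, m))
feat_enc = rng.normal(loc=0.0, size=(n, m))

# The divergence alpha @ (mu_p - mu_q) is linear in alpha, so over the unit
# ball it is maximized by alpha* = delta / ||delta||, with value ||delta||.
delta = feat_prior.mean(axis=0) - feat_enc.mean(axis=0)
alpha_star = delta / np.linalg.norm(delta)
mmd_hat = float(alpha_star @ delta)
```

This is the sense in which the formulation approaches MMD as $m$ grows: the empirical feature means concentrate around their population counterparts.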
The overall loss includes:
- The usual reconstruction penalty $\mathcal{L}_{\text{rec}} = \mathbb{E}_{x}\big[\|x - P_\theta(Q_\phi(x))\|^2\big]$,
- The adversarial penalty $\lambda\, D_f(p, q_\phi)$,
- A constraint or regularization on the adversary's weights $\alpha$ (equivalently, on $\|f\|$).
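A hypothetical sketch of how these three terms might be assembled from precomputed quantities, assuming squared-error reconstruction and an L2 penalty on the feature weights (the function name and the `lam`, `reg` values are illustrative, not the paper's):

```python
import numpy as np

def ds_aae_losses(x, x_recon, feat_enc, feat_prior, alpha, lam=1.0, reg=1e-3):
    """Assemble the three DS-AAE loss terms.

    x, x_recon : (n, d_x) inputs and decoder reconstructions
    feat_enc   : (n, m) random features of encoded codes Q_phi(x)
    feat_prior : (n, m) random features of prior draws z ~ p(z)
    alpha      : (m,) adversary weights over the random features
    """
    recon = float(np.mean(np.sum((x - x_recon) ** 2, axis=1)))
    adv = float(alpha @ (feat_prior.mean(axis=0) - feat_enc.mean(axis=0)))
    penalty = reg * float(alpha @ alpha)
    return {
        "recon": recon,
        "adv": adv,
        "encoder_obj": recon + lam * adv,  # minimized over phi, theta
        "adversary_obj": adv - penalty,    # maximized over alpha
    }

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))
feat_enc = rng.normal(size=(8, 16))
feat_prior = rng.normal(loc=0.3, size=(8, 16))
losses = ds_aae_losses(x, x + 0.1, feat_enc, feat_prior, rng.normal(size=16))
```

Note the sign convention: `adv` estimates $\mathbb{E}_p[f] - \mathbb{E}_q[f]$, so the encoder drives it down by moving encoded features toward regions where $f$ is large, while the adversary drives it up.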
4. Training Algorithm
Training in DS-AAE alternates between batch sampling of data and batch sampling of random features. The two sources of randomness are crucial for doubly stochastic optimization. The training loop proceeds as follows:
- Sample a minibatch of data $\{x_i\}_{i=1}^{n}$.
- Encode: $z_i = Q_\phi(x_i)$.
- Sample prior codes $z'_j \sim p(z)$, $j = 1, \dots, n$.
- Sample random features $\omega_k \sim \Lambda$, $k = 1, \dots, m$, and compute the feature maps $\phi_{\omega_k}(\cdot)$.
- Compute the stochastic gradient terms:
  - $\hat\mu_p = \frac{1}{n} \sum_j \phi_\omega(z'_j)$ (prior feature mean),
  - $\hat\mu_q = \frac{1}{n} \sum_i \phi_\omega(z_i)$ (posterior feature mean).
- Update the adversary (the feature weights $\alpha$) by ascending the objective, including regularization.
- Update the encoder and decoder parameters $\phi, \theta$ using gradients from the reconstruction loss and the adversarial penalty.
Increasing the number of feature samples or using control variates mitigates variance introduced by random feature sampling.
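As a toy illustration of the adversary's step alone (the encoder and decoder updates would be handled by backpropagation and are omitted), the regularized objective $\alpha^\top \Delta - \lambda \|\alpha\|^2$ is linear-quadratic in $\alpha$, so noisy gradient ascent with fresh randomness each iteration converges to $\mathbb{E}[\Delta] / (2\lambda)$. All constants below are illustrative:

```python
import numpy as np

rng = np.random.default_rng(2)
m, lam, lr = 50, 0.5, 0.1
alpha = np.zeros(m)

for _ in range(200):
    # Fresh randomness each step stands in for resampling data, prior
    # codes, and random features; the true feature-mean gap is all-ones.
    delta = np.ones(m) + 0.1 * rng.normal(size=m)
    # Ascend the regularized adversary objective: alpha @ delta - lam * ||alpha||^2.
    alpha += lr * (delta - 2.0 * lam * alpha)

# Stationary point of the expected update: alpha* = E[delta] / (2 * lam) = 1.
```

The per-step noise never vanishes, which is precisely the exploration-inducing randomness the method relies on; averaging over larger feature batches shrinks it.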
5. Theoretical Properties
The convergence of the algorithm leverages results from stochastic optimization in Reproducing Kernel Hilbert Spaces (RKHS) [6], demonstrating that, when the adversary step size is small, iterates remain in the RKHS induced by the kernel $k(z, z') = \mathbb{E}_{\omega \sim \Lambda}[\phi_\omega(z)\, \phi_\omega(z')]$. The minimax optimization thus converges to a stationary Nash point.
Introducing randomness via the draws $\omega \sim \Lambda$ prevents the adversary from perfectly overfitting to the current generator, which compels the generator to explore the latent space more broadly, encouraging the discovery of additional latent modes and supporting a more uniform coverage of the prior $p(z)$.
6. Experimental Evaluation
Experiments were conducted primarily on MNIST (28×28), with additional preliminary results on CIFAR-10. Network architectures employed three fully connected layers (1024→512→216) with ReLU activations and a final sigmoid for the decoder. The latent dimension was set to 6 for DS-AAE, with 20% input dropout and an Adam optimizer (learning rate 0.001). Minibatch and random feature batch sizes were both set to 1000, with an RBF kernel.
Parzen-window log-likelihoods for 10,000 MNIST samples are summarized below:
| Model | Parzen LL (mean ± std) |
|---|---|
| GAN [3] | 225 ± 2 |
| GMMN+AE [5] | 282 ± 2 |
| AAE [1] | 340 ± 2 |
| MMD-AE [5] | 228 ± 1.6 |
| DS-AAE | 243.2 ± 1.7 |
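For reference, the Parzen-window metric fits an isotropic Gaussian kernel density to generated samples and reports the mean log-likelihood of held-out test points. A minimal NumPy sketch (the bandwidth `sigma` is a free parameter, typically tuned on a validation split):

```python
import numpy as np

def parzen_log_likelihood(test_x, gen_x, sigma):
    """Mean Parzen-window log-likelihood of test points under a Gaussian
    kernel density whose centres are the generated samples.

    test_x : (n, d) held-out samples
    gen_x  : (k, d) generated samples used as kernel centres
    sigma  : kernel bandwidth
    """
    n, d = test_x.shape
    k = gen_x.shape[0]
    # Squared distances between every test point and every centre: (n, k).
    sq = np.sum((test_x[:, None, :] - gen_x[None, :, :]) ** 2, axis=-1)
    log_kernel = -sq / (2.0 * sigma ** 2) - 0.5 * d * np.log(2.0 * np.pi * sigma ** 2)
    # log (1/k) sum_j N(x | c_j, sigma^2 I), computed stably via log-sum-exp.
    mx = log_kernel.max(axis=1, keepdims=True)
    log_p = mx.squeeze(1) + np.log(np.exp(log_kernel - mx).sum(axis=1)) - np.log(k)
    return float(log_p.mean())
```

The metric is known to be sensitive to the bandwidth and the number of generated samples, which is worth bearing in mind when comparing the rows of the table.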
DS-AAE samples display higher visual diversity, with more heterogeneous digit styles across generated panels compared to the relatively homogeneous samples produced by standard AAE or MMD-AE. Latent space interpolations are sharp and cover multiple Gaussian modes, indicating a reduction in mode collapse and improved coverage of the latent manifold.
7. Extensions and Open Questions
DS-AAE demonstrates that replacing a fixed discriminator with a stochastic function space regularizes minimax training and promotes exploration, loosely yielding a continuous spectrum between GAN-based and MMD-based regularization. The doubly stochastic scheme (data plus random features) is integral to this effect.
However, DS-AAE is sensitive to the size of data and feature batches; using small batch sizes diminishes exploration. Its convergence may be slower than that of deterministic AAEs due to the additional randomness. Directions for further research include convolutional DS-AAEs, adaptive random feature sampling strategies, improved variance reduction, and establishing generalization bounds for the doubly stochastic adversary (Azarafrooz, 2018).