Ensemble Gumbel-Softmax (EGS) in NAS

Updated 20 January 2026
  • Ensemble Gumbel-Softmax (EGS) is a differentiable estimator that aggregates multiple relaxed samples to enable multi-operation selection in neural architecture search.
  • It reduces gradient variance by averaging over M independent Gumbel-Softmax samples, facilitating stable and efficient optimization.
  • EGS integrates seamlessly into NAS frameworks, jointly optimizing network weights and architecture parameters to achieve state-of-the-art results on image and language tasks.

The Ensemble Gumbel-Softmax (EGS) estimator is a technique developed to improve the effectiveness and efficiency of differentiable neural architecture search (NAS) by enabling the simultaneous optimization of network architecture and parameters in a fully differentiable, end-to-end manner. EGS builds upon the standard Gumbel-Softmax reparameterization by aggregating several independent relaxed samples, thereby supporting the selection of multiple operations per decision point and providing a low-variance, differentiable gradient estimator suited for gradient-based NAS frameworks (Chang et al., 2019).

1. Standard Gumbel-Softmax and Its Relaxation

The foundation of EGS lies in the Gumbel-Softmax, also known as the Concrete distribution. Given a categorical distribution over $K$ classes parameterized by unnormalized scores $\alpha_1, \ldots, \alpha_K$, sampling is achieved by the Gumbel-Max trick:

$$c = \arg\max_k \left( \log\alpha_k + G_k \right)$$

with $G_k \sim \text{Gumbel}(0,1)$ for each $k$. Since $\arg\max$ is non-differentiable, Gumbel-Softmax introduces a softmax relaxation in the log-domain, yielding a differentiable probability vector

$$y_k = \frac{\exp((\log\alpha_k + G_k)/\tau)}{\sum_j \exp((\log\alpha_j + G_j)/\tau)}$$

where $G_k = -\log(-\log U_k)$ with $U_k \sim \text{Uniform}(0,1)$, and $\tau > 0$ is a temperature hyperparameter. As $\tau \to 0^+$, the sample approaches a one-hot vector; as $\tau \to \infty$, it approaches a uniform distribution. Differentiability with respect to $\log\alpha_k$ allows gradients to flow through the architecture selection process.
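The relaxation above can be sketched in a few lines of NumPy (an illustrative implementation, not code from the original paper):

```python
import numpy as np

def gumbel_softmax(log_alpha, tau=1.0, rng=None):
    """Draw one relaxed sample on the probability simplex from logits log_alpha."""
    rng = rng or np.random.default_rng()
    u = rng.uniform(size=log_alpha.shape)   # U_k ~ Uniform(0, 1)
    g = -np.log(-np.log(u))                 # G_k ~ Gumbel(0, 1)
    z = (log_alpha + g) / tau               # perturbed, temperature-scaled logits
    z = z - z.max()                         # subtract max for numerical stability
    y = np.exp(z)
    return y / y.sum()

log_alpha = np.log(np.array([0.1, 0.3, 0.6]))
y = gumbel_softmax(log_alpha, tau=0.5)
# y lies on the simplex; as tau -> 0+ it concentrates on a single entry
```

Because the noise enters additively before a smooth softmax, the output is differentiable in `log_alpha`, which is exactly what lets an optimizer update architecture scores by backpropagation.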

2. Ensemble Sampling Mechanism

While standard Gumbel-Softmax samples a single operation per edge, EGS enables the multi-operation selection essential for richer architectural search spaces. EGS produces $M$ independent Gumbel-Softmax samples $y^{(1)}, \ldots, y^{(M)} \in \Delta^{K-1}$ and aggregates them into a binary decision vector $\hat{y} \in \{0,1\}^K$ by

$$\hat{y}_k = \max_{m=1,\ldots,M} \left[ \text{hard\_onehot}(y^{(m)}) \right]_k$$

Alternatively, an ensemble-averaged soft assignment is constructed as

$$\bar{y} = \frac{1}{M} \sum_{m=1}^{M} y^{(m)}$$

with discretization performed by retaining the top $M$ entries or via thresholding, thus allowing up to $M$ operations to be selected per edge.
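A minimal NumPy sketch of the ensemble step for a single edge with $K$ candidate operations (function and variable names are illustrative, not from the paper):

```python
import numpy as np

def egs_sample(log_alpha, M=4, tau=0.5, rng=None):
    """Draw M Gumbel-Softmax samples and aggregate them as in EGS."""
    rng = rng or np.random.default_rng(0)
    K = log_alpha.shape[0]
    g = -np.log(-np.log(rng.uniform(size=(M, K))))   # M x K Gumbel(0,1) noise
    z = (log_alpha + g) / tau
    y = np.exp(z - z.max(axis=1, keepdims=True))
    y = y / y.sum(axis=1, keepdims=True)             # M soft samples, one per row
    y_bar = y.mean(axis=0)                           # ensemble-averaged soft vector
    y_hat = np.zeros(K)
    y_hat[y.argmax(axis=1)] = 1.0                    # union of per-sample one-hots
    return y_bar, y_hat

log_alpha = np.log(np.array([0.1, 0.2, 0.3, 0.4]))
y_bar, y_hat = egs_sample(log_alpha, M=4)
# y_bar sums to 1; y_hat is binary with between 1 and M entries set
```

Because different samples can land on different argmax indices, the union `y_hat` can activate anywhere from one to $M$ operations on the edge, which is the multi-operation behavior standard Gumbel-Softmax cannot express.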

3. Differentiability, Gradient Estimation, and Variance Reduction

EGS targets minimization of the expected loss over sampled architectures:

$$J(\alpha, w) = \mathbb{E}_{A \sim p_\alpha}\left[ L(w; A) \right]$$

where $A$ is a discrete architecture sampled by EGS, $w$ the network weights, and $\alpha$ the architecture logits. The reparameterization trick ensures that each Gumbel draw is a deterministic function of uniform randomness, rendering $y^{(m)}(\alpha, U^{(m)})$ differentiable with respect to $\alpha$. Instead of using a single-sample gradient, EGS averages over $M$ samples to compute

$$\bar{g} = \frac{1}{M} \sum_{m=1}^{M} g^{(m)}$$

with $g^{(m)} = \partial L / \partial y^{(m)}$, and gradients are backpropagated through $\bar{y}$. The sample mean reduces variance by a factor of $M$, leveraging the independence of the $y^{(m)}$ conditional on $\alpha$.
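The $1/M$ variance reduction can be checked with a toy Monte Carlo experiment; the fixed linear functional `w` below stands in for a real per-sample gradient and is purely illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
log_alpha = np.log(np.array([0.2, 0.3, 0.5]))
tau = 0.5

def soft_samples(n):
    """Draw n Gumbel-Softmax samples (rows) for the fixed logits above."""
    g = -np.log(-np.log(rng.uniform(size=(n, 3))))
    z = (log_alpha + g) / tau
    y = np.exp(z - z.max(axis=1, keepdims=True))
    return y / y.sum(axis=1, keepdims=True)

# A fixed linear functional of y stands in for the per-sample gradient g^(m).
w = np.array([1.0, -2.0, 0.5])

def estimator_variance(M, trials=2000):
    """Variance of the M-sample-averaged estimate across repeated trials."""
    return np.var([soft_samples(M).mean(axis=0) @ w for _ in range(trials)])

v1, v8 = estimator_variance(1), estimator_variance(8)
# With independent samples, the variance shrinks roughly as 1/M, so v8 << v1
```

This mirrors the argument in the text: averaging $M$ i.i.d. estimates leaves the mean unchanged while dividing the variance by $M$.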

4. Algorithmic Integration into Differentiable NAS

The EGS procedure can be integrated into standard NAS workflows that involve joint or bi-level optimization of weights $w$ and architecture parameters $\alpha$. The NAS search loop proceeds as follows:

```
for epoch in range(1, T+1):
    for batch in D_train:
        for edge in cell:
            # EGS sampling: M independent Gumbel-Softmax draws per edge
            for m in range(1, M+1):
                G^{(m)} = -log(-log U^{(m)}),  U^{(m)} ~ Uniform(0,1)^K
                y_e^{(m)} = softmax((log α_e + G^{(m)}) / τ)
            ȳ_e = (1/M) * sum_m y_e^{(m)}
            # Mixed operation: weighted sum of candidate operations
            o_e(x) = sum_k ȳ_e[k] * O_k(x)
        output = model_w(x, {ȳ_e})
        loss = ℓ(output, y)
        # Backpropagation through ȳ_e
        w ← w − η_w ∇_w loss
        α ← α − η_α ∇_α loss
    # Optionally anneal the temperature τ
```
At search termination, discrete architectures are fixed by selecting the top-$M$ operations per edge according to the logits $\alpha$.
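The final top-$M$ discretization can be sketched as follows, assuming an $(E, K)$ logit matrix `alpha` (shapes and names are hypothetical):

```python
import numpy as np

def discretize(alpha, M):
    """Keep the top-M operations per edge; alpha has shape (E, K)."""
    top = np.argsort(alpha, axis=1)[:, -M:]     # indices of the M largest logits
    mask = np.zeros_like(alpha, dtype=int)
    np.put_along_axis(mask, top, 1, axis=1)
    return mask                                 # binary (E, K) selection matrix

alpha = np.array([[0.1, 0.9, 0.4],
                  [0.7, 0.2, 0.3]])
mask = discretize(alpha, M=2)
# each row keeps exactly M = 2 operations, those with the largest logits
```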

5. Hyperparameters, Computational Complexity, and Theoretical Guarantees

EGS introduces two key hyperparameters: the ensemble size $M$ and the temperature $\tau$.

  • Ensemble size $M$: dictates the number of Gumbel-Softmax samples per edge, controlling search-space expressivity and gradient variance. Proposition 3 confirms EGS can represent all binary patterns with up to $M$ active operations per edge, with search capacity scaling as $\mathcal{O}(K^M)$.
  • Temperature $\tau$: lower $\tau$ produces near one-hot outputs, favoring discrete architectures; higher $\tau$ yields softer selections and easier optimization. Annealing $\tau$ from approximately $1.0$ to $0.1$ during search is common.
  • Computational cost: sampling and combining $M$ Gumbel-Softmax samples incurs cost $\mathcal{O}(M \cdot K)$ per edge, or $\mathcal{O}(E \cdot M \cdot K)$ for a cell with $E$ edges. Empirically, $M = 4$ to $8$ provides sufficient coverage at modest overhead.
  • Theoretical properties: Proposition 2 states that EGS is a monotonically increasing set function of the selection probabilities, and the variance of the estimator scales inversely with $M$.

6. Empirical Evaluation Across Tasks

EGS has demonstrated competitive or superior performance on canonical NAS benchmarks at low computational cost. On CIFAR-10, DARTS-EGS (with $M=7$) achieves $2.79\%$ error in 1 GPU-day, improving upon DARTS and SNAS. On ImageNet (mobile setting), DARTS-EGS attains $24.9\%$ top-1 error in 1.5 GPU-days, outperforming DARTS and NASNet at orders-of-magnitude lower search cost. In PTB language modeling, DARTS-EGS ($M=7$) reaches $57.1 / 55.3$ validation/test perplexity in 0.5 GPU-days, while transfer to WikiText-2 (WT2) yields $66.5 / 64.2$ perplexity. Ablation on $M$ reveals monotonically improving results as the ensemble size grows:

| $M$ | CIFAR-10 Error (%) |
|---|---|
| 1 | 3.38 |
| 4 | 3.05 |
| 7 | 2.79 |
| 9 | 2.73 |

Consistently, larger $M$ enhances final model accuracy and reduces gradient variance at acceptable computational cost.

7. Context, Significance, and Integration with NAS Frameworks

EGS generalizes prior gradient-based NAS searchers by supporting multi-operation selection and stabilizing gradients via ensembling, thereby extending the utility of continuous relaxation frameworks (DARTS-style) where joint optimization of discrete architecture and network weights is required. The capacity to sample exponentially more architectures and reduce estimator variance positions EGS as a principal technique for efficient, scalable architecture search. EGS is compatible with both joint and bi-level optimization schemes, and naturally integrates into feed-forward search spaces with arbitrary connectivity.

Ensemble Gumbel-Softmax thus extends the classical Gumbel-Softmax by facilitating diverse, gradient-efficient search over architectural building blocks, yielding state-of-the-art performance on both image and language tasks at practical search cost (Chang et al., 2019).
