- The paper demonstrates that sample variance decay occurs in deep ReLU networks even when total variance is maintained through Kaiming initialization.
- It employs a mean-field theoretical framework and empirical experiments to reveal how ReLU activations and network depth drive the decay process.
- It highlights that techniques like Batch Normalization and sample variance-preserving initialization can counteract representational collapse and improve training dynamics.
Sample Variance Decay in Randomly Initialized ReLU Networks: A Technical Analysis
Introduction
The initialization of weights in deep neural networks is a critical factor influencing training dynamics and convergence. The canonical approach, particularly for ReLU-based architectures, is Kaiming initialization, which aims to preserve the variance of pre-activations across layers. Traditionally, this preservation is interpreted in terms of the total variance, i.e., the variance over both random network initializations and input samples. However, an alternative and arguably more relevant perspective is to consider the variance over samples for a fixed network realization. This paper rigorously analyzes the distinction between these two interpretations, demonstrating that, in deep ReLU networks, the sample variance decays with depth even when total variance is preserved. The implications of this phenomenon are explored both theoretically and empirically, with further analysis of the impact of Batch Normalization and data-dependent initialization schemes.
Theoretical Framework
Decomposition of Variance
The total variance of pre-activations in a given layer can be decomposed as follows:
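$$(\sigma^l)^2 = (m^l)^2 + (v^l)^2$$
where $(m^l)^2$ is the network-averaged squared sample mean and $(v^l)^2$ is the network-averaged sample variance of a pre-activation in layer $l$. The decomposition holds because, with zero-mean random weights and zero biases, pre-activations have zero mean when averaged over both networks and samples. Kaiming initialization ensures that $(\sigma^l)^2$ is constant across layers, but does not guarantee the preservation of $(v^l)^2$.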
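The decomposition can be checked in a few lines. Here $\mathbb{E}_\theta$ denotes the average over random network realizations and $\mathbb{E}_x$ the average over input samples (notation introduced for this derivation, not taken from the paper). The derivation assumes zero-mean weights and zero biases, so the pre-activation $z^l$ of a unit has zero mean over both:

```latex
\begin{aligned}
(\sigma^l)^2 &= \mathbb{E}_\theta\,\mathbb{E}_x\big[(z^l)^2\big] \\
&= \mathbb{E}_\theta\Big[\big(\mathbb{E}_x[z^l]\big)^2 + \mathrm{Var}_x\big[z^l\big]\Big] \\
&= \underbrace{\mathbb{E}_\theta\big[(\mathbb{E}_x[z^l])^2\big]}_{(m^l)^2}
 + \underbrace{\mathbb{E}_\theta\big[\mathrm{Var}_x[z^l]\big]}_{(v^l)^2}
\end{aligned}
```

The first equality uses the zero overall mean. The second applies $\mathbb{E}[y^2] = (\mathbb{E}[y])^2 + \mathrm{Var}[y]$ over samples for each fixed network before averaging over networks.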
Analytical Results in the Infinite-Width Limit
Assuming infinite width and IID normalized inputs, the propagation of sample mean and variance can be analyzed using the mean-field formalism and the kernel recursion introduced in "Exponential expressivity in deep neural networks through transient chaos" (Poole et al., 2016). The key result is that the expected cosine similarity between activations of two different samples increases monotonically with depth, converging to 1. Consequently, the sample variance $(v^l)^2$ decays to zero as depth increases, while the squared sample mean $(m^l)^2$ approaches the total variance $(\sigma^l)^2$. The decay is subexponential, because the derivative of the kernel recursion at the fixed point is 1.
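For ReLU with Kaiming scaling (weight variance $2/n$, zero biases), the infinite-width correlation recursion has a standard closed form (the degree-1 arc-cosine kernel): $c^{l+1} = \frac{1}{\pi}\big(\sqrt{1-(c^l)^2} + (\pi - \arccos c^l)\,c^l\big)$. The minimal sketch below iterates it from $c = 0$ (uncorrelated inputs) and reports $1 - c^l$. Under the further assumption of many IID samples with equal pairwise correlation, this quantity approximates the normalized sample variance $(v^l)^2/(\sigma^l)^2$. The function names are hypothetical.

```python
import math


def relu_correlation_map(c: float) -> float:
    """One layer of the infinite-width correlation recursion for ReLU with Kaiming scaling."""
    c = max(-1.0, min(1.0, c))  # guard against round-off just outside [-1, 1]
    return (math.sqrt(1.0 - c * c) + (math.pi - math.acos(c)) * c) / math.pi


def normalized_sample_variance(depth: int, c0: float = 0.0) -> list[float]:
    """Return 1 - c^l for l = 0..depth, starting from input correlation c0."""
    c = c0
    out = [1.0 - c]
    for _ in range(depth):
        c = relu_correlation_map(c)
        out.append(1.0 - c)
    return out


if __name__ == "__main__":
    v = normalized_sample_variance(100)
    for layer in (0, 1, 10, 100):
        print(f"layer {layer:3d}: 1 - c = {v[layer]:.4f}")
    # The per-layer contraction ratio creeps toward 1: subexponential decay.
    print(f"ratio v[100] / v[99] = {v[100] / v[99]:.4f}")
```

Each layer shrinks $1 - c$, but the per-layer ratio tends to 1. The slope of the map is $(\pi - \arccos c)/\pi$, which equals 1 at $c = 1$, so the decay is subexponential rather than exponential.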
Qualitative Mechanism
The decay of sample variance is attributed to the combined effect of random matrix multiplication and the ReLU nonlinearity. Multiplication by a random weight matrix preserves, in expectation, the ratio of the squared sample mean to the sample variance. The ReLU, by contrast, increases the mean and decreases the variance of its input distribution, which raises that ratio. Every layer applies both operations, so the ratio grows with depth, and the sample variance accounts for an ever smaller share of the total variance that Kaiming initialization holds fixed.
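The ReLU half of this argument can be checked numerically under an infinite-width, many-sample model, which is an assumption made here for illustration. Consider a layer with total variance 1 in which the sample mean accounts for a share $c$. Each unit's sample mean is drawn from $\mathcal{N}(0, c)$, and its values over samples are Gaussian around that mean with variance $1 - c$. Before the ReLU, the layer's squared-mean-to-variance ratio is therefore $c/(1-c)$. The sketch computes the same ratio after the ReLU, using closed-form truncated-Gaussian moments for each unit and Gauss-Hermite quadrature over units. The function names are hypothetical.

```python
import math

import numpy as np
from numpy.polynomial.hermite_e import hermegauss

_erf = np.vectorize(math.erf)


def relu_moments_given_mean(mu, s):
    """Mean and variance of max(z, 0) for z ~ N(mu, s^2), elementwise over mu."""
    a = mu / s
    pdf = np.exp(-0.5 * a ** 2) / math.sqrt(2.0 * math.pi)
    cdf = 0.5 * (1.0 + _erf(a / math.sqrt(2.0)))
    mean = mu * cdf + s * pdf
    second = (mu ** 2 + s ** 2) * cdf + mu * s * pdf
    return mean, second - mean ** 2


def layer_ratio_after_relu(c, n_nodes=80):
    """Layer-averaged squared sample mean over sample variance after a ReLU.

    Pre-activations: unit sample means ~ N(0, c), spread over samples with variance 1 - c.
    """
    x, w = hermegauss(n_nodes)        # nodes/weights for the weight function exp(-x^2 / 2)
    w = w / math.sqrt(2.0 * math.pi)  # now sum(w * g(x)) approximates E[g(X)], X ~ N(0, 1)
    mu = math.sqrt(c) * x             # per-unit sample means
    mean, var = relu_moments_given_mean(mu, math.sqrt(1.0 - c))
    return float(np.sum(w * mean ** 2) / np.sum(w * var))


if __name__ == "__main__":
    for c in (0.0, 0.3, 0.6, 0.9):
        print(f"c={c}: before {c / (1 - c):.3f}, after {layer_ratio_after_relu(c):.3f}")
```

At $c = 0$ the post-ReLU ratio evaluates to $1/(\pi - 1) \approx 0.467$, so even perfectly centered pre-activations acquire a nonzero sample mean after a single ReLU.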
Empirical Validation
Finite-Width and Real-World Architectures
Empirical studies on finite-width MLPs show that sample variance decays more slowly as width decreases, but the qualitative behavior persists. Experiments on ALL-CNN-C (CIFAR10) and UNet (ISBI2012) architectures further demonstrate that sample variance decay is present in practical convolutional networks, although architectural features such as global pooling and skip connections can modulate the effect.
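A small finite-width MLP experiment in this spirit can be sketched with NumPy. It is not a reproduction of the paper's setup. The depth, width, batch size, and Gaussian inputs are arbitrary choices, and the statistics are averaged over the units of a single network as a stand-in for the average over random networks.

```python
import numpy as np


def sample_variance_fraction(depth=50, width=512, n_samples=256, seed=0):
    """Per-layer ratio (v^l)^2 / (sigma^l)^2 for one Kaiming-initialized ReLU MLP.

    Squared sample means and sample variances are averaged over the units of a
    single network, as a proxy for the average over random networks.
    """
    rng = np.random.default_rng(seed)
    h = rng.standard_normal((n_samples, width))  # IID normalized inputs
    fractions = []
    for _ in range(depth):
        # Kaiming (He) normal scaling for ReLU, zero biases.
        W = rng.standard_normal((width, width)) * np.sqrt(2.0 / width)
        z = h @ W                                  # pre-activations: (samples, units)
        mean_sq = np.mean(z.mean(axis=0) ** 2)     # estimate of (m^l)^2
        var = np.mean(z.var(axis=0))               # estimate of (v^l)^2
        fractions.append(float(var / (mean_sq + var)))
        h = np.maximum(z, 0.0)                     # ReLU
    return fractions


if __name__ == "__main__":
    f = sample_variance_fraction()
    for layer in (0, 9, 24, 49):
        print(f"layer {layer + 1:2d}: v^2 / sigma^2 = {f[layer]:.3f}")
```

With these settings the fraction should start close to 1 in the first layer and shrink with depth. Rerunning with different `width` values is one way to probe how the decay depends on width.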
Impact of Batch Normalization
Batch Normalization, by construction, enforces zero sample mean and fixed sample variance for each pre-activation at initialization. This intervention eliminates sample variance decay, placing the network in a regime where the gradients grow exponentially with depth. The theoretical prediction for the gradient explosion factor matches empirical observations, confirming the analysis.
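The gradient growth under Batch Normalization can be observed in a hand-written NumPy forward and backward pass. This is a sketch under stated assumptions. BatchNorm has no learnable scale and shift, which is equivalent to their initial values of 1 and 0. Weights are Kaiming-scaled, and inputs are Gaussian. The loss is hypothetical: a random linear readout of the last layer, summed over the batch. The backward step uses the standard BatchNorm gradient formula.

```python
import numpy as np


def bn_relu_mlp_grad_norms(depth=20, width=256, batch=64, eps=1e-5, seed=0):
    """Gradient norm w.r.t. each layer's pre-BatchNorm activations, first layer first."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal((batch, width))
    cache = []
    for _ in range(depth):
        W = rng.standard_normal((width, width)) * np.sqrt(2.0 / width)  # Kaiming scale
        z = h @ W
        inv_std = 1.0 / np.sqrt(z.var(axis=0) + eps)
        zhat = (z - z.mean(axis=0)) * inv_std  # zero sample mean, unit sample variance
        cache.append((W, zhat, inv_std))
        h = np.maximum(zhat, 0.0)              # ReLU

    # Hypothetical loss: sum over the batch of a random linear readout of h.
    readout = rng.standard_normal(width)
    dh = np.tile(readout, (batch, 1))
    grad_norms = []
    for W, zhat, inv_std in reversed(cache):
        dzhat = dh * (zhat > 0)                # ReLU backward
        # BatchNorm backward (no affine parameters), per feature over the batch.
        dz = inv_std / batch * (
            batch * dzhat - dzhat.sum(axis=0) - zhat * (dzhat * zhat).sum(axis=0)
        )
        grad_norms.append(float(np.linalg.norm(dz)))
        dh = dz @ W.T                          # matmul backward
    return grad_norms[::-1]


if __name__ == "__main__":
    g = bn_relu_mlp_grad_norms()
    print(f"first layer / last layer gradient norm: {g[0] / g[-1]:.1f}")
```

The returned list runs from the first layer to the last. A large ratio `g[0] / g[-1]` therefore indicates gradients growing from the output back toward the input.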
Implications for Training and Initialization
Data-Dependent Initialization
The paper compares two initialization schemes:
- Scale-only (total variance preservation): Weights are scaled per-layer to preserve total variance; biases are set to zero.
- Scale+bias (sample variance preservation): Biases are set per-feature to zero the sample mean; weights are scaled to preserve sample variance.
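The two schemes can be sketched as layer-by-layer, data-dependent initializations computed on a single batch. Several details are assumptions where the description above is silent. Weights start as standard Gaussians and are rescaled by one scalar per layer. The total-variance target is approximated by the empirical second moment over the batch and units. The function names are hypothetical.

```python
import numpy as np


def data_dependent_init(x, widths, mode, seed=0):
    """Initialize a ReLU MLP layer by layer from a batch x of shape (samples, features).

    mode="scale":      one weight scale per layer so that the mean of z**2 over samples
                       and units is 1 (total-variance proxy); biases are zero.
    mode="scale+bias": one weight scale per layer so that the unit-averaged sample
                       variance is 1, then per-feature biases that zero each sample mean.
    """
    rng = np.random.default_rng(seed)
    params, h = [], x
    for width in widths:
        W = rng.standard_normal((h.shape[1], width))
        z = h @ W
        if mode == "scale":
            W /= np.sqrt(np.mean(z ** 2))
            b = np.zeros(width)
        elif mode == "scale+bias":
            W /= np.sqrt(np.mean(z.var(axis=0)))
            b = -(h @ W).mean(axis=0)
        else:
            raise ValueError(f"unknown mode: {mode}")
        params.append((W, b))
        h = np.maximum(h @ W + b, 0.0)  # ReLU output feeds the next layer's init
    return params


def preactivation_stats(x, params):
    """Per layer: (unit-averaged squared sample mean, unit-averaged sample variance)."""
    stats, h = [], x
    for W, b in params:
        z = h @ W + b
        stats.append((float(np.mean(z.mean(axis=0) ** 2)), float(np.mean(z.var(axis=0)))))
        h = np.maximum(z, 0.0)
    return stats


if __name__ == "__main__":
    x = np.random.default_rng(1).standard_normal((128, 64))
    for mode in ("scale", "scale+bias"):
        s = preactivation_stats(x, data_dependent_init(x, [128] * 20, mode))
        print(mode, "last-layer sample variance:", round(s[-1][1], 3))
```

Comparing the per-layer statistics from `preactivation_stats` makes the difference concrete. The scale+bias mode pins the sample variance at 1 in every layer on the initialization batch. The scale-only mode keeps only the second moment at 1, which leaves room for the squared mean to take over.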
Empirical results indicate that sample variance-preserving initialization (scale+bias) leads to faster training convergence compared to total variance-preserving initialization. The benefit persists well into training, suggesting that the network requires significant time to recover from poor initial sample statistics.
Gradient Dynamics
Preserving total variance (Kaiming) results in stable gradient norms across layers, but at the cost of driving the network into a nearly linear regime in deep layers. There, pre-activations are dominated by large, input-independent means, so most ReLUs are either active for almost every input or inactive for almost every input, and therefore act as fixed linear maps. In contrast, sample variance-preserving schemes (including BatchNorm) induce highly nonlinear behavior but at the expense of exponentially growing gradients.
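One way to quantify the near-linear regime, chosen here for illustration rather than taken from the paper, is to count units whose pre-activation has the same sign for every sample in a batch. For such a unit the ReLU acts as either the identity or zero. The sketch measures this fraction per layer in a Kaiming-initialized MLP with hypothetical sizes.

```python
import numpy as np


def fraction_linear_units(depth=50, width=256, n_samples=64, seed=0):
    """Per layer, the fraction of units whose pre-activation sign is identical for
    every sample in the batch, so the ReLU treats them as a fixed linear map."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal((n_samples, width))  # IID normalized inputs
    fractions = []
    for _ in range(depth):
        # Kaiming (He) normal scaling for ReLU, zero biases.
        W = rng.standard_normal((width, width)) * np.sqrt(2.0 / width)
        z = h @ W
        same_sign = np.all(z > 0, axis=0) | np.all(z < 0, axis=0)
        fractions.append(float(same_sign.mean()))
        h = np.maximum(z, 0.0)
    return fractions


if __name__ == "__main__":
    f = fraction_linear_units()
    print(f"layer 1: {f[0]:.2f}, layer 50: {f[-1]:.2f}")
```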
Theoretical and Practical Implications
The analysis clarifies the distinction between total and sample variance preservation, highlighting that Kaiming initialization, while effective for shallow networks, induces undesirable representational collapse in deep ReLU networks. The findings provide a theoretical foundation for the empirical success of Batch Normalization and data-dependent initialization schemes, which maintain nontrivial sample variance and thus enable effective utilization of nonlinearities throughout the network.
The results also connect to the "ordered" and "chaotic" regimes described in prior work (Poole et al., 2016; Schoenholz et al., 2016), with Kaiming initialization corresponding to the ordered regime (stable gradients, linear behavior) and sample variance-preserving schemes corresponding to the chaotic regime (exploding gradients, nonlinear behavior).
Future Directions
The analysis is primarily restricted to wide, fully connected ReLU networks at initialization. Extending the theoretical framework to finite-width, convolutional, and residual architectures remains an open challenge. Additionally, the long-term impact of initialization-induced sample variance decay on generalization and representation learning warrants further investigation, particularly in the context of modern architectures that employ normalization and skip connections.
Conclusion
This work rigorously demonstrates that Kaiming initialization, while preserving total variance, leads to sample variance decay in deep ReLU networks, resulting in representational collapse and near-linearity in higher layers. Batch Normalization and data-dependent initialization schemes counteract this effect, preserving sample variance and enabling effective nonlinear computation at all depths, albeit with the trade-off of gradient explosion. These findings have significant implications for the design and initialization of deep neural networks, providing a theoretical basis for the widespread adoption of normalization techniques and motivating further research into initialization strategies that balance gradient stability and representational richness.