- The paper demonstrates that recursive training on generated data systematically erodes the model's ability to represent true data distributions, resulting in collapse.
- Analytical and empirical analyses using GMMs, VAEs, and LLMs quantify how both statistical and functional approximation errors contribute to performance degradation.
- The study highlights the importance of preserving original data to maintain accurate modeling and prevent degradation across successive training generations.
The paper investigates the implications of training generative models, specifically LLMs, on data generated by previous iterations of similar models. The central finding is the identification of "model collapse," a degenerative process in which models progressively lose the ability to represent the true underlying data distribution, with the tails of the distribution disappearing over time. This phenomenon is shown to occur in Gaussian Mixture Models (GMMs), Variational Autoencoders (VAEs), and LLMs, suggesting it is ubiquitous across learned generative models.
The authors identify two primary causes of model collapse:
- Statistical approximation error, which arises due to the finite number of samples used during training.
- Functional approximation error, stemming from limitations in the expressiveness of the function approximators used in the models.
The paper argues that access to the original data distribution is crucial for sustaining the benefits of training from large-scale data, especially for capturing low-probability events, which are often relevant to marginalized groups and to understanding complex systems. The authors propose that data about genuine human interactions with systems will become increasingly valuable as LLM-generated content proliferates in data crawled from the Internet.
The paper presents a theoretical analysis of model collapse, using simplified mathematical models to provide analytical expressions for quantities of interest. The analysis focuses on quantifying how different sources of error affect the overall approximation of the original distribution. The authors consider two cases: a discrete distribution in the absence of functional approximation error, and a single-dimensional Gaussian case that portrays how functional approximation error can compound with statistical error.
Key theoretical results include:
- Demonstration that for discrete distributions with exact approximation, model collapse arises solely due to statistical errors from the sampling step, leading to the eventual convergence to a delta function.
- Derivation of a lower bound on the risk, defined via the Wasserstein-2 distance from the true distribution, for a single-dimensional Gaussian. The risk grows linearly with the number of generations, indicating that the sample size must grow superlinearly across generations to maintain an accurate approximation of the original distribution.
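The discrete case can be illustrated with a small Monte Carlo sketch. The alphabet size, sample count, and generation count below are hypothetical choices, not values from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: 10 equally likely symbols, M samples per generation.
p = np.full(10, 0.1)
M = 50

# Each generation re-estimates the distribution from the empirical
# frequencies of a finite sample, then samples from that estimate.
support_sizes = []
for _ in range(500):
    counts = rng.multinomial(M, p)
    p = counts / M
    support_sizes.append(int((p > 0).sum()))

print(support_sizes[:5], support_sizes[-1])
```

Once a symbol's estimated probability hits zero it can never be sampled again, so the support can only shrink; in the long run the chain is absorbed at a delta function on a single symbol, exactly the statistical-error-only collapse described above.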
The paper also presents empirical results that support the theoretical analysis. Specifically, the authors demonstrate model collapse in GMMs and VAEs trained from scratch, showing that the models progressively lose information about the tails of the distribution and converge to a distribution with very small variance.
In the context of LLMs, the paper investigates the effects of fine-tuning OPT-125m on data generated by previous iterations of the model. The results show that models trained on generated data exhibit degraded performance compared to models trained on original data. The generated data also exhibit longer tails, suggesting that the models are starting to misperceive reality based on errors introduced by their ancestors.
The authors conduct experiments with different training regimes, including training for 5 epochs with no original training data and training for 10 epochs with 10% of the original training data preserved. Both regimes lead to degraded performance, but the preservation of original data allows for better model fine-tuning and leads to only minor degradation of performance.
The paper also addresses the issue of repeating phrases in generated text, showing that explicitly encouraging models to produce non-repeating sequences does not curb the effects of model collapse.
The paper concludes by discussing the implications of model collapse for the long-term sustainability of LLM training. The authors emphasize the importance of preserving access to the original data source and of distinguishing LLM-generated data from other data. They suggest that community-wide coordination may be necessary to establish the provenance of content crawled from the Internet; without it, training newer versions of LLMs may become increasingly difficult absent access to pre-LLM data or to data generated directly by humans.
In the theoretical analysis, the authors model the learning process with generational data as a stochastic process. At generation $i$, the dataset $\mathcal{D}_i$ consists of i.i.d. random variables $X^i_j$, where $j \in \{1, \dots, M_i\}$ and $M_i \geq 2$. The distribution of $X^i$ is denoted $p_i$, with $p_0$ representing the original distribution. The transition from generation $i$ to $i+1$ involves estimating the distribution of the samples in $\mathcal{D}_i$ with an approximation $p_{\theta_{i+1}}$, where $\mathcal{F}_\theta : p_i \to p_{\theta_{i+1}}$ represents the functional approximation. The dataset $\mathcal{D}_{i+1}$ is then resampled from the distribution $p_{i+1} = \alpha_i p_{\theta_{i+1}} + \beta_i p_i + \gamma_i p_0$, with non-negative parameters $\alpha_i, \beta_i, \gamma_i$ summing to $1$.
For the single-dimensional Gaussian case, the authors consider $X^0 \sim \mathcal{N}(\mu, \sigma^2)$ and estimate the sample mean and variance using:

$\mu_{i+1} = \frac{1}{M_i} \sum_j X^i_j$

- $\mu_{i+1}$ is the estimated sample mean at generation $i+1$
- $M_i$ is the sample size at generation $i$
- $X^i_j$ represents the samples at generation $i$

$\sigma^2_{i+1} = \frac{1}{M_i - 1} \sum_j \left(X^i_j - \mu_{i+1}\right)^2$

- $\sigma^2_{i+1}$ is the estimated (unbiased) sample variance at generation $i+1$
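The recursive fit-and-resample loop these estimators define can be sketched as follows; the constant sample size, standard-normal start, and generation count are hypothetical choices for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

mu, sigma, M = 0.0, 1.0, 100
data = rng.normal(mu, sigma, M)

means, variances = [], []
for _ in range(1_000):
    mu_hat = data.mean()        # mu_{i+1} = (1/M_i) sum_j X^i_j
    var_hat = data.var(ddof=1)  # sigma^2_{i+1}, the unbiased estimator
    means.append(mu_hat)
    variances.append(var_hat)
    data = rng.normal(mu_hat, np.sqrt(var_hat), M)  # resample generation i+1

# Each generation injects O(1/M) estimation noise: the mean performs a
# random walk away from mu while the estimated variance tends to shrink.
print(round(variances[0], 3), round(variances[-1], 6))
```

This is the mechanism behind the empirical observation that the models converge to a distribution with very small variance while drifting away from the true mean.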
They then derive the following expression for $X^n_j$:

$X^n_j = \mu + \frac{\sigma}{\sqrt{M_0}} Z^1 + \frac{\sigma}{\sqrt{M_1}} \sqrt{S^1}\, Z^2 + \dots + \frac{\sigma}{\sqrt{M_{n-1}}} \sqrt{S^1 \times \dots \times S^{n-1}}\, Z^n + \sigma \sqrt{S^1 \times \dots \times S^n}\, Z^n_j$

- $Z^i$ are random variables distributed as $\mathcal{N}(0, 1)$
- $S^i$ are random variables distributed as $\frac{1}{M_{i-1} - 1} \Gamma\!\left(\frac{M_{i-1} - 1}{2}, \frac{1}{2}\right)$
Assuming a constant sample size $M_i = M$, they derive the following approximation:

$\mathrm{Var}(X^n_j) = \sigma^2 \left(1 + \frac{n}{M}\right)$
The authors then use the Wasserstein-2 distance to measure the distance between the true distribution and the approximated distribution at step n+1:
$R^{n+1}_{W_2} := W_2^2\!\left(\mathcal{N}(\mu, \sigma^2), \mathcal{N}(\mu_{n+1}, \sigma^2_{n+1})\right) = \|\mu_{n+1} - \mu\|^2 + \|\sigma_{n+1} - \sigma\|^2$
Finally, they calculate the risk as:
$\mathbb{E}_{\mu_{n+1}, \sigma^2_{n+1}}\!\left[R^{n+1}_{W_2}\right] = \sigma^2 \left(\frac{1}{M_0} + \frac{1}{M_1} + \dots + \frac{3}{2 M_n}\right)$
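A tiny sketch evaluating this expected risk for a constant sample size (the values of $\sigma^2$ and $M$ below are purely illustrative) makes the linear growth in $n$ concrete:

```python
# Expected risk E[R^{n+1}_{W2}] = sigma^2 (1/M_0 + ... + 1/M_{n-1} + 3/(2 M_n)),
# where sample_sizes lists M_0, ..., M_n.
def expected_risk(sigma2, sample_sizes):
    *head, last = sample_sizes
    return sigma2 * (sum(1.0 / m for m in head) + 1.5 / last)

sigma2, M = 1.0, 1000
risks = [expected_risk(sigma2, [M] * (n + 1)) for n in range(1, 6)]
print(risks)
```

With $M_i = M$ the expression reduces to $\sigma^2 (n + 3/2)/M$, so each extra generation adds a fixed $\sigma^2 / M$ to the risk; keeping the risk bounded as $n \to \infty$ therefore requires the sample sizes to grow superlinearly, as noted above.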