
Towards Fully FP8 GEMM LLM Training at Scale

Published 26 May 2025 in cs.LG (arXiv:2505.20524v1)

Abstract: Despite the significant potential of FP8 data formats for LLM pre-training, their adoption has been limited due to challenges in maintaining stability at scale. Existing approaches often rely on suboptimal fine-grained FP8 kernels or fall back to higher-precision matrix multiplications (GEMMs) in sensitive components, such as attention projections, compromising potential throughput gains. We introduce a new class of LLM architectures that, for the first time, support FP8 computation for all GEMMs within transformer blocks during both forward and backward passes. This enables unprecedented throughput gains, particularly at scale, while matching the downstream performance of standard BF16 training. Our architecture design reduces large outlier activations, promoting stable long-term FP8 training. In addition, we identify key metrics to monitor low-precision training and predict potential future divergences.

Summary

  • The paper presents FOG, a novel class of architectures that runs all transformer-block GEMMs in FP8, boosting throughput by up to 40% over BF16 training.
  • It utilizes techniques like RMS normalization and QK entropy regularization to counteract the numerical instabilities inherent in low-precision FP8 formats.
  • Experimental results across multiple model scales confirm that FOG architectures maintain robust training with sustained low activation kurtosis levels.


Introduction

FP8 data formats promise significant efficiency gains for large-scale LLM training by reducing the compute cost of matrix multiplications. However, stability issues during training have constrained their adoption: existing approaches keep sensitive parts of the architecture, such as attention projections, in higher precision, which limits throughput gains. This paper presents a class of transformer architectures, FOG, that performs FP8 computation for all GEMMs within the transformer blocks, achieving throughput improvements of up to 40% while matching the downstream performance of BF16 training.

FP8 Training Challenges

FP8 formats are appealing because their lower precision accelerates computation on supporting hardware. However, the narrow dynamic range of FP8 makes training more susceptible to numerical instability, such as underflow and overflow. Transformer activations that produce large outlier features exacerbate this risk and typically require careful scaling when moving down from higher-precision formats. The FOG architectures address these outlier-related instabilities by modifying standard transformer designs to promote stable FP8 training throughout the entire run (Figure 1).

Figure 1: From OP to the converged architecture, step by step. The first run to diverge is OP; OP with frozen QK RMSNorm survives the stable phase but diverges during learning-rate cooldown; the converged run adds post-normalization.
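To make the dynamic-range problem concrete, the sketch below (an illustrative simulation, not the paper's kernels; the function names and the simplified rounding scheme are assumptions) quantizes a tensor to a crude FP8 E4M3 approximation with per-tensor scaling, showing how a single outlier activation degrades the precision of all the other values:

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3

def round_to_e4m3(x):
    """Crude E4M3 rounding: keep 3 mantissa bits; subnormal/NaN details ignored."""
    m, e = np.frexp(x)             # x = m * 2**e with |m| in [0.5, 1)
    m = np.round(m * 16.0) / 16.0  # 8 representable steps per binade
    return np.ldexp(m, e)

def fp8_quant_dequant(x):
    """Per-tensor scaling: map the absolute max onto E4M3_MAX, round, rescale."""
    scale = E4M3_MAX / np.abs(x).max()
    x_fp8 = np.clip(x * scale, -E4M3_MAX, E4M3_MAX)
    return round_to_e4m3(x_fp8) / scale

small = np.full(1000, 1e-3)
err_clean = np.abs(fp8_quant_dequant(small) - small).max()

with_outlier = np.concatenate([small, [100.0]])  # one large outlier activation
err_outlier = np.abs(fp8_quant_dequant(with_outlier)[:1000] - small).max()
# The outlier inflates the per-tensor scale, so the small entries lose precision.
```

This is exactly why outlier suppression matters: with per-tensor scaling, one extreme activation dictates the scale for the whole tensor.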

FOG Architecture Design

FOG architectures prioritize preventing large activation outliers to stabilize FP8 training. They alter existing transformer designs by omitting pre-normalization blocks and incorporating RMS normalization and QK entropy regularization to prevent entropy collapse, maintaining healthy signal propagation and minimizing training divergences. These adjustments make FP8 usable across more components than previously feasible (Figure 2).
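As a rough sketch of one ingredient, RMS normalization rescales each activation vector by its root-mean-square, bounding the magnitudes that feed subsequent FP8 GEMMs (a generic formulation; the exact placement of norms in the FOG block is not reproduced here):

```python
import numpy as np

def rms_norm(x, gain, eps=1e-6):
    """RMS normalization over the feature dimension (no mean subtraction)."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * gain

# Applying it to queries and keys ("QK norm") bounds attention-logit magnitudes.
x = np.random.default_rng(0).standard_normal((4, 512)) * 50.0  # large activations
y = rms_norm(x, gain=np.ones(512))  # each row now has RMS ~= 1
```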

Figure 2: Kurtosis of QKV tensors during FP8DPA learning rate cooldown with OP+frozenQK architecture. Later layers show significantly larger activation outliers.
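The entropy collapse that QK regularization guards against can be monitored directly: when each query's attention distribution concentrates on a single key, its Shannon entropy approaches zero. A minimal sketch (illustrative only; the paper's exact regularizer form is an assumption):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerically stable softmax
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_entropy(logits):
    """Mean Shannon entropy of attention rows: ~log(seq_len) when uniform,
    ~0 when each query attends to a single key (entropy collapse)."""
    p = softmax(logits)
    return float(-(p * np.log(p + 1e-12)).sum(axis=-1).mean())

uniform = attention_entropy(np.zeros((2, 8, 8)))        # near log(8)
collapsed = attention_entropy(np.eye(8)[None] * 100.0)  # near 0
```

A regularizer would then penalize the model when this quantity drifts toward zero.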

Long-term Outlier Dynamics

Kurtosis quantifies the prevalence of outliers in key transformer activations, serving as a metric for how extreme activation values become during training. By monitoring kurtosis, FOG's long-term FP8 stability can be verified: sustained low kurtosis over extended training runs indicates reduced susceptibility to outlier-induced divergence, supporting the architecture's robustness claims (Figure 3).

Figure 3: Loss and kurtosis training dynamics over 100B tokens, comparing a standard GLU-based model with pre-normalized activation variants.
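The monitoring metric itself is cheap to compute. A sketch of per-tensor kurtosis (the fourth standardized moment, ≈3 for a Gaussian; which tensors the paper tracks and how values are aggregated are assumptions here):

```python
import numpy as np

def kurtosis(x):
    """Fourth standardized moment: E[(x - mu)^4] / Var(x)^2; equals 3 for a Gaussian."""
    x = np.asarray(x, dtype=np.float64).reshape(-1)
    centered = x - x.mean()
    var = np.mean(centered ** 2)
    return float(np.mean(centered ** 4) / (var ** 2 + 1e-12))

rng = np.random.default_rng(0)
healthy = rng.standard_normal(100_000)  # well-behaved activations, kurtosis ~3
spiky = healthy.copy()
spiky[:10] = 50.0                       # a handful of outlier features
k_healthy, k_spiky = kurtosis(healthy), kurtosis(spiky)
# A rising kurtosis flags outliers well before the loss itself diverges.
```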

Experimental Results

Extensive experiments validate the proposed architectures across scales and against adjusted baselines. Conducted at several model sizes (0.4B, 1.5B, and 8B parameters), the results confirm that FOG architectures match BF16 baselines on downstream performance while delivering throughput improvements of up to 40%. Training used FP8 computation for all transformer-block GEMMs, including the attention-related projections (FP8DPA), showcasing significant efficiency gains (Figure 4).

Figure 4: Training dynamics of failed vs. successful FP8DPA runs. Kurtosis diverging earlier than loss demonstrates its utility for prediction.

Limitations and Future Work

Though FOG architectures advance FP8 utilization within transformers, they still fall back to BF16 for specific sensitive computations, such as the final projection. Future work could explore more general FP8 application and lower-precision optimizer states to further reduce memory burdens without compromising performance (Figure 5).

Figure 5: Cross-entropy loss plots across different architectures over long-data training regimes, showing FOG's capacity to train continuously without divergence.

Conclusion

FOG architectures represent a significant advance in enabling FP8 LLM training across scales without sacrificing downstream performance. This practical approach improves efficiency and provides a framework for stable long-term training, positioning FP8 formats as viable contenders for widespread LLM training (Figure 6).

Figure 6: Long-data training regimes with FP8DPA scaled progressively against established higher precision benchmarks.

The research presented establishes the viability of FP8 for extensive LLM training, paving the way for further exploration of architecturally supportive components and techniques to enhance FP8 adoption in the future.
