Training LLMs with MXFP4

Published 27 Feb 2025 in cs.LG | (2502.20586v3)

Abstract: Low precision (LP) datatypes such as MXFP4 can accelerate matrix multiplications (GEMMs) and reduce training costs. However, directly using MXFP4 instead of BF16 during training significantly degrades model quality. In this work, we present the first near-lossless training recipe that uses MXFP4 GEMMs, which are $2\times$ faster than FP8 on supported hardware. Our key insight is to compute unbiased gradient estimates with stochastic rounding (SR), resulting in more accurate model updates. However, directly applying SR to MXFP4 can result in high variance from block-level outliers, harming convergence. To overcome this, we use the random Hadamard transform to theoretically bound the variance of SR. We train GPT models up to 6.7B parameters and find that our method induces minimal degradation over mixed-precision BF16 training. Our recipe computes $>1/2$ the training FLOPs in MXFP4, enabling an estimated speedup of $>1.3\times$ over FP8 and $>1.7\times$ over BF16 during backpropagation.

Summary

  • The paper introduces an MXFP4 training recipe that combines stochastic rounding with the Random Hadamard Transform to produce unbiased, low-variance gradient estimates for efficient LLM training.
  • Experiments on GPT models up to 6.7B parameters show minimal degradation relative to mixed-precision BF16 training, with estimated backpropagation speedups of more than 1.3× over FP8 and 1.7× over BF16.
  • The methodology paves the way for using low precision datatypes to reduce computational costs while maintaining model quality in large-scale training.

Efficient Low Precision Training with MXFP4

The paper "Training LLMs with MXFP4" (2502.20586) explores how to use a low precision datatype, MXFP4, for efficient training of LLMs. The proposed method speeds up training while mitigating the loss of model quality typically associated with low precision formats.

Introduction to MXFP4 for LLM Training

MXFP4 is a low precision datatype that can accelerate training by reducing the cost of general matrix multiplications (GEMMs). Despite this potential, directly replacing BF16 with MXFP4 during training significantly degrades model quality. To overcome this limitation, the paper introduces a training recipe for MXFP4 that keeps model quality close to that of BF16 while offering significant speed advantages.
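
The summary does not spell out the format itself. As a rough sketch, assuming the layout described in the OCP Microscaling (MX) specification (blocks of 32 elements sharing a power-of-two scale, with each element stored as a 4-bit E2M1 float), round-to-nearest MXFP4 quantization might look like the following. The function names are hypothetical and the scale rule is a simplified reading of the specification, not the paper's implementation.

```python
import math

# Magnitudes representable by an FP4 E2M1 element (the sign is a separate bit).
FP4_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK_SIZE = 32  # number of elements that share one scale (assumed from the MX spec)
FP4_EMAX = 2     # exponent of the largest E2M1 binade (4.0 == 2**2)


def shared_scale(block):
    """Power-of-two scale derived from the largest magnitude in the block."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0
    return 2.0 ** (math.floor(math.log2(amax)) - FP4_EMAX)


def quantize_block_nearest(block):
    """Scale the block, round each element to the nearest FP4 value, rescale."""
    scale = shared_scale(block)
    out = []
    for v in block:
        mag = min(abs(v) / scale, FP4_GRID[-1])  # clamp to the largest FP4 value
        q = min(FP4_GRID, key=lambda g: abs(g - mag))
        out.append(math.copysign(q, v) * scale)
    return out
```

For a block such as `[0.01] * 31 + [8.0]`, the single outlier sets the shared scale to 2, so the outlier is kept exactly while every small value rounds to zero. This is one way to see why block-level outliers matter for this format.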

The methodology revolves around two core innovations: the use of stochastic rounding (SR) to compute unbiased gradients and the Random Hadamard Transform (RHT) to bound the variance introduced by stochastic rounding (Figure 1).

Figure 1: Our method uses stochastic rounding (SR) to compute unbiased gradients and the random Hadamard transform to bound the variance of SR. This enables us to perform more accurate model updates with MXFP4 in the backward pass, enabling a speedup of $>1.3\times$ over FP8 and $>1.7\times$ over BF16.

Unbiased Gradient Estimates with MXFP4

Round-to-nearest MXFP4 quantization is biased: rounding errors do not average out to zero, so the resulting gradient estimates are systematically distorted. To address this, the authors apply stochastic rounding, which rounds each value up or down at random with probabilities chosen so that the expected result equals the original value. The gradient estimates are therefore unbiased, but SR adds variance of its own, and block-level outliers can make that variance large enough to harm convergence.
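
To make the unbiasedness property concrete, here is a minimal sketch of stochastic rounding onto the FP4 (E2M1) value grid. It assumes the input has already been divided by its block scale and lies within the representable range; the function name and the explicit `rng` argument are illustrative choices, not the paper's kernel.

```python
import random

# Magnitudes representable by an FP4 E2M1 element (the sign is a separate bit).
FP4_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def stochastic_round(x, rng):
    """Randomly round x to one of its two neighboring FP4 values.

    The upper neighbor is picked with probability equal to the fractional
    position of |x| inside its grid interval, so E[output] == x for inputs
    in [-6, 6]. Inputs outside that range are clamped (an assumption here),
    which would reintroduce bias.
    """
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), FP4_GRID[-1])
    # Locate the grid interval [lo, hi] that contains mag.
    for lo, hi in zip(FP4_GRID, FP4_GRID[1:]):
        if lo <= mag <= hi:
            break
    p_up = (mag - lo) / (hi - lo)
    return sign * (hi if rng.random() < p_up else lo)


if __name__ == "__main__":
    rng = random.Random(0)
    samples = [stochastic_round(2.4, rng) for _ in range(20000)]
    # Each sample is 2.0 or 3.0; the average should be close to 2.4.
    print(sum(samples) / len(samples))
```

Round-to-nearest would map 2.4 to 2.0 every time, a fixed error of 0.4. Stochastic rounding returns 3.0 about 40% of the time instead, so the error averages out across many values at the price of per-sample noise.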

The Random Hadamard Transform (RHT) addresses this variance. Applying the RHT before quantization spreads block-level outliers across the entries of a block, concentrating the distribution of values and allowing the authors to bound the variance of SR theoretically, which leads to more stable training.
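
The effect of the transform on an outlier can be illustrated with a small sketch. It assumes the common construction of a random Hadamard transform (random sign flips followed by an orthonormal Walsh-Hadamard transform) applied to a single block of 32 values; the helper names are hypothetical, and the paper's actual block size and kernel may differ.

```python
import math
import random


def fwht(v):
    """Unnormalized fast Walsh-Hadamard transform; len(v) must be a power of two."""
    v = list(v)
    h = 1
    while h < len(v):
        for start in range(0, len(v), 2 * h):
            for i in range(start, start + h):
                a, b = v[i], v[i + h]
                v[i], v[i + h] = a + b, a - b
        h *= 2
    return v


def random_hadamard(v, signs):
    """Apply random sign flips, then an orthonormal Hadamard transform."""
    n = len(v)
    flipped = [s * x for s, x in zip(signs, v)]
    return [y / math.sqrt(n) for y in fwht(flipped)]


if __name__ == "__main__":
    rng = random.Random(0)
    signs = [rng.choice((-1.0, 1.0)) for _ in range(32)]
    block = [0.0] * 31 + [8.0]  # a single large outlier
    spread = random_hadamard(block, signs)
    print(max(abs(y) for y in spread))  # outlier energy is shared by all 32 entries
```

Because the transform is orthogonal, it preserves norms and inner products: transforming two vectors with the same signs leaves their dot product unchanged. That property is what makes it plausible to transform both GEMM operands along the reduction dimension, quantize them, and still recover an estimate of the original product.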

Experimental Results

The authors tested their approach on GPT models with up to 6.7B parameters, observing minimal degradation relative to mixed-precision BF16 training. Validation perplexity curves for models trained with MXFP4, RHT, and SR in the backward pass closely follow those of BF16 training (Figures 2-4).

Figure 2: GPT 345M validation perplexity curves with BF16 forward pass. With RHT and SR, MXFP4 can match the performance of BF16 in the backward pass.

Figure 3: GPT 1.3B validation perplexity curves with BF16 forward pass. With RHT and SR, MXFP4 can match the performance of BF16 in the backward pass.

Figure 4: GPT 6.7B validation perplexity curves with BF16 forward pass. With RHT and SR, MXFP4 can match the performance of BF16 in the backward pass. The MXFP4-only run was stopped early to save resources.

The authors estimate a backpropagation speedup of more than 1.3× over FP8 and more than 1.7× over BF16, with more than half of the training FLOPs computed in MXFP4. These efficiency gains come with minimal degradation in model quality.

Implications and Future Work

The proposed method offers a pathway to use low precision formats effectively in LLM training by mitigating the quality degradation these formats typically cause. As hardware support for low precision operations continues to advance, the approach holds promise for reducing the computational cost of large-scale model training.

Future work could investigate further optimizations in hardware-specific implementations or extend this technique to other low precision datatypes within the microscaling family to explore their potential benefits in LLM training pipelines.

Conclusion

"Training LLMs with MXFP4" introduces a recipe that achieves near-lossless training with MXFP4 GEMMs while substantially accelerating backpropagation compared to higher precision formats. By combining stochastic rounding with the Random Hadamard Transform, the approach is a promising candidate for efficiently training increasingly large and computationally expensive models.
