
ReMax: A Simple, Effective, and Efficient Reinforcement Learning Method for Aligning Large Language Models

Published 16 Oct 2023 in cs.LG (arXiv:2310.10505v4)

Abstract: Reinforcement Learning from Human Feedback (RLHF) is key to aligning LLMs, typically paired with the Proximal Policy Optimization (PPO) algorithm. While PPO is a powerful method designed for general reinforcement learning tasks, it is overly sophisticated for LLMs, leading to laborious hyper-parameter tuning and significant computation burdens. To make RLHF efficient, we present ReMax, which leverages three properties of RLHF: fast simulation, deterministic transitions, and trajectory-level rewards. These properties are not exploited in PPO, making it less suitable for RLHF. Building on the renowned REINFORCE algorithm, ReMax does not require training an additional value model as in PPO and is further enhanced with a new variance reduction technique. ReMax offers several benefits over PPO: it is simpler to implement, eliminates more than four hyper-parameters in PPO, reduces GPU memory usage, and shortens training time. ReMax saves about 46% of the GPU memory used by PPO when training a 7B model and enables training on A800-80GB GPUs without the memory-saving offloading technique that PPO requires. Applying ReMax to a Mistral-7B model resulted in a 94.78% win rate on the AlpacaEval leaderboard and a 7.739 score on MT-bench, setting a new SOTA for open-source 7B models. These results show the effectiveness of ReMax while addressing the limitations of PPO in LLMs.


Summary

  • The paper introduces ReMax, a reinforcement learning method that eliminates the value model to simplify LLM alignment and reduce computational overhead.
  • It builds on the REINFORCE policy gradient with a new variance-reduction baseline to stabilize training, cutting GPU memory usage by about 46% on 7B models.
  • Empirical tests on Mistral-7B show that ReMax achieves competitive, state-of-the-art performance on benchmarks while improving training efficiency.

Introduction

"ReMax: A Simple, Effective, and Efficient Reinforcement Learning Method for Aligning LLMs" (2310.10505) introduces ReMax, a method designed to align LLMs more efficiently compared to the widely used Proximal Policy Optimization (PPO). The core idea exploits the intrinsic properties of Reinforcement Learning from Human Feedback (RLHF) for LLM training: fast simulation, deterministic transitions, and trajectory-level rewards.

Architectural Innovations

ReMax differentiates itself from PPO by eliminating the value model, relying instead on a variance reduction technique derived from the REINFORCE algorithm. This simplifies implementation, substantially reduces computational overhead, and preserves alignment quality in large-scale models.

ReMax Design:

  • Reference Model: Utilizes a reference model for KL-divergence penalties, similar to PPO.
  • No Value Model: ReMax eliminates the need for a separate value model, reducing GPU memory requirements (Figure 1).

    Figure 1: Building blocks of PPO and ReMax. ReMax keeps the reference model (for calculating KL penalty) and removes all the components related to the value model in PPO.

Computational Efficiency

ReMax is designed to address the heavy GPU memory and computation-time demands traditionally associated with PPO. For instance, ReMax cuts GPU memory consumption by approximately 46% when training a 7B model and avoids the memory-saving offloading techniques that PPO requires (Figure 2).

Figure 2: GPU memory consumption and training time of PPO and ReMax, measured on a Llama-2-7B model using A800-80GB GPUs.
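
The size of this saving can be sanity-checked with a rough back-of-the-envelope estimate. The assumptions below are not from the paper: PPO for RLHF holds four 7B-scale models (a trained actor and value model plus a frozen reference and reward model), mixed-precision Adam costs roughly 16 bytes per parameter for trained models and 2 bytes for frozen fp16 models, and activations are ignored:

```python
# Rough estimate of the memory saved by dropping PPO's value model.
# Assumed (not from the paper): trained models cost ~16 bytes/param
# under mixed-precision Adam (fp16 weights + fp16 grads + fp32 master
# copy, momentum, and variance); frozen fp16 models cost ~2 bytes/param.
TRAINED_BPP = 16  # actor and value model
FROZEN_BPP = 2    # reference and reward model

ppo_total = 2 * TRAINED_BPP + 2 * FROZEN_BPP    # actor + value + ref + reward
remax_total = 1 * TRAINED_BPP + 2 * FROZEN_BPP  # value model removed
saving = 1 - remax_total / ppo_total
print(f"estimated saving: {saving:.0%}")
```

This crude model gives about 44%, in the same ballpark as the roughly 46% measured in the paper; the exact figure depends on batch size, activations, and the ZeRO configuration.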

Reinforcement Learning Adaptation

ReMax employs reward-weighted likelihood maximization, rooted in the REINFORCE algorithm. This avoids the complexity of PPO's generalized advantage estimation and clipped importance-ratio updates. Instead, ReMax introduces a novel variance reduction technique that exploits the deterministic environment of LLMs under RLHF.
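
The underlying principle can be illustrated on a toy problem. The sketch below runs plain REINFORCE (reward-weighted likelihood maximization, without a baseline) on an invented 2-armed bandit; the bandit, its rewards, and the learning rate are all assumptions for illustration, not the paper's setup:

```python
import math
import random

# Toy REINFORCE on a 2-armed bandit: scale the log-likelihood gradient
# of each sampled action by the reward it received.
random.seed(0)
theta = [0.0, 0.0]  # logits of a softmax policy over two actions

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

for _ in range(2000):
    probs = softmax(theta)
    a = random.choices([0, 1], weights=probs)[0]
    reward = 1.0 if a == 1 else 0.0  # action 1 stands in for the "good" response
    # Gradient of log pi(a) w.r.t. the logits: one-hot(a) - probs.
    for k in range(2):
        grad_log_pi = (1.0 if k == a else 0.0) - probs[k]
        theta[k] += 0.1 * reward * grad_log_pi  # ascend the reward-weighted log-likelihood

print(f"P(action 1) after training: {softmax(theta)[1]:.3f}")
```

After training, the policy concentrates on the rewarded action. The same estimator, applied per token over a complete response, is what ReMax optimizes.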

Key Properties:

  • Fast Simulation: LLMs generate complete responses quickly, making trajectory collection cheap.
  • Deterministic Transitions: LLM environments are predictable, reducing noise in training.
  • Trajectory-level Rewards: Evaluating only the complete output simplifies the reward structure (Figure 3).

    Figure 3: Illustration of StarCraft II (a general RL task example) and RLHF in LLMs.

Variance Reduction and Stability

The variance of REINFORCE's stochastic gradient is mitigated in ReMax by subtracting the reward of the greedy-decoded response as a baseline. Because this baseline does not depend on the sampled response, the gradient estimate remains unbiased while its variance shrinks, which is pivotal for stable training of large-scale models.
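
In symbols, with prompt x, sampled response y of length T, and greedy response ȳ, the ReMax gradient estimator takes the form (notation paraphrased from the paper):

```latex
\widehat{\nabla}_\theta J(\theta)
  = \big( r(x, y) - r(x, \bar{y}) \big)\,
    \nabla_\theta \sum_{t=1}^{T} \log \pi_\theta\!\left(y_t \mid x, y_{<t}\right)
```

Since the baseline r(x, ȳ) does not depend on the sampled y, subtracting it leaves the estimator's expectation unchanged while reducing its variance.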

Algorithm Implementation:

import torch

def remax_align(language_model, reward_model, dataset):
    # Pseudocode: language_model and reward_model expose a hypothetical interface.
    for prompt in dataset:
        # Sample a stochastic response and a greedy (argmax-decoded) response.
        seq = language_model.sample(prompt, greedy=False)
        seq_max = language_model.sample(prompt, greedy=True)
        # The greedy response's reward acts as a variance-reducing baseline.
        rewd_diff = reward_model.evaluate(seq) - reward_model.evaluate(seq_max)
        # Per-token log-probabilities of the sampled response.
        log_prob = language_model.log_probs(prompt, seq)
        # Reward-weighted likelihood: sum of log-probs scaled by the
        # baseline-adjusted reward, negated to form a loss.
        loss = -torch.mean(torch.sum(log_prob, dim=-1) * rewd_diff)
        language_model.optimize(loss)
    return language_model

Practical Implications

In empirical tests, ReMax matched or exceeded PPO's performance while requiring less hyper-parameter tuning and exhibiting more stable training. When applied to Mistral-7B, ReMax set a new state-of-the-art (SOTA) for open-source 7B models on the AlpacaEval and MT-bench leaderboards.

Conclusion

ReMax represents a pragmatic evolution in aligning LLMs through RLHF, reducing the resource burden and making the training of large models more accessible. Its balance of simplicity, efficiency, and effectiveness may encourage widespread adoption and refinement in real-world applications.
