Enhancing Reasoning through Process Supervision with Monte Carlo Tree Search

Published 2 Jan 2025 in cs.AI, cs.CL, and cs.LG | (2501.01478v1)

Abstract: LLMs have demonstrated their remarkable capacity across a variety of tasks. However, reasoning remains a challenge for LLMs. To improve LLMs' reasoning ability, process supervision has proven to be better than outcome supervision. In this work, we study using Monte Carlo Tree Search (MCTS) to generate process supervision data with LLMs themselves for training them. We sample reasoning steps with an LLM and assign each step a score that captures its "relative correctness," and the LLM is then trained by minimizing weighted log-likelihood of generating the reasoning steps. This generate-then-train process is repeated iteratively until convergence. Our experimental results demonstrate that the proposed methods considerably improve the performance of LLMs on two mathematical reasoning datasets. Furthermore, models trained on one dataset also exhibit improved performance on the other, showing the transferability of the enhanced reasoning ability.

Summary

  • The paper shows that using Monte Carlo Tree Search (MCTS) to generate process-supervision data improves LLM reasoning on mathematical benchmarks.
  • The method iteratively samples reasoning steps, scores each step, and trains the model by minimizing a score-weighted negative log-likelihood with a KL penalty.
  • Experiments on MATH and GSM8K show that the iterative training converges quickly and that the gains transfer across datasets.

Introduction

The paper "Enhancing Reasoning through Process Supervision with Monte Carlo Tree Search" studies how to improve the reasoning of LLMs through process supervision, which rewards individual reasoning steps, rather than outcome supervision, which rewards only the final answer. Its main contribution is using Monte Carlo Tree Search (MCTS) to generate process-supervision data with the LLM itself. The LLM samples reasoning steps, each step receives a score reflecting its "relative correctness," and the LLM is then trained by minimizing a weighted negative log-likelihood of those steps.

Methodology

The proposed method uses MCTS to sample and search step-by-step reasoning paths for each training problem, producing process-supervision data. Each search follows the four standard MCTS stages: Selection, Expansion, Simulation, and Backpropagation. Every node in the search tree corresponds to a reasoning step and is assigned a score, defined by Eq. (1) of the paper, that reflects the step's relative correctness. These scores weight a negative log-likelihood loss, which is combined with a KL penalty term to guide training (Figure 1).

Figure 1: An overview of the proposed methods.
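To make the four MCTS stages and the per-step scores concrete, here is a minimal Python sketch on a hypothetical toy task. Everything in it is an assumption rather than the paper's implementation: the task (pick three steps from {+1, +2, +3} so the total is 6) stands in for a math problem, randomly choosing an untried step stands in for sampling a step from the LLM, the UCT constant c = 1.4 is arbitrary, and the score Q(node) - Q(parent) is only a stand-in for the relative-correctness score of Eq. (1), which may be defined differently in the paper.

```python
import math
import random

# Hypothetical toy task: choose DEPTH steps from STEPS so the total equals TARGET.
STEPS = (1, 2, 3)
DEPTH = 3
TARGET = 6


class Node:
    def __init__(self, path, parent=None):
        self.path = path        # reasoning steps taken so far (tuple)
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value_sum = 0.0    # sum of rollout rewards (1.0 = correct final answer)

    @property
    def q(self):
        return self.value_sum / self.visits if self.visits else 0.0

    def is_terminal(self):
        return len(self.path) == DEPTH


def reward(path):
    return 1.0 if sum(path) == TARGET else 0.0


def uct(child, parent_visits, c=1.4):
    if child.visits == 0:
        return math.inf
    return child.q + c * math.sqrt(math.log(parent_visits) / child.visits)


def mcts(n_simulations=500, seed=0):
    rng = random.Random(seed)
    root = Node(())
    for _ in range(n_simulations):
        node = root
        # 1. Selection: descend by UCT while the current node is fully expanded.
        while node.children and len(node.children) == len(STEPS):
            node = max(node.children, key=lambda ch: uct(ch, node.visits))
        # 2. Expansion: add one untried next step (stand-in for sampling from the LLM).
        if not node.is_terminal():
            tried = {ch.path[-1] for ch in node.children}
            step = rng.choice([s for s in STEPS if s not in tried])
            child = Node(node.path + (step,), parent=node)
            node.children.append(child)
            node = child
        # 3. Simulation: finish the reasoning randomly and check the final answer.
        rollout = list(node.path)
        while len(rollout) < DEPTH:
            rollout.append(rng.choice(STEPS))
        r = reward(rollout)
        # 4. Backpropagation: update visit counts and values up to the root.
        while node is not None:
            node.visits += 1
            node.value_sum += r
            node = node.parent
    return root


def relative_scores(node, out=None):
    """Collect (path, score) for every non-root node.

    ASSUMPTION: score = Q(node) - Q(parent), a hypothetical stand-in for Eq. (1).
    """
    if out is None:
        out = []
    for ch in node.children:
        out.append((ch.path, ch.q - node.q))
        relative_scores(ch, out)
    return out
```

Calling `relative_scores(mcts())` returns (path, score) pairs: a step that raises the estimated chance of reaching the correct total gets a positive score, and a step that lowers it gets a negative one. This is the kind of per-step signal a weighted log-likelihood loss can consume.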

Training is iterative. The first iteration starts from the pretrained LLM; in each iteration, the current model generates new process-supervision data with MCTS and is then trained on that data, so every cycle refines the model on data it produced itself. The cycle repeats until performance converges.
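The generate-then-train loop, together with the score-weighted negative log-likelihood and KL penalty, can be illustrated with a toy Python sketch. It is not the paper's training code: the "model" is a hypothetical categorical policy over three candidate next steps for a single problem, the per-step scores are fixed values assumed to come from MCTS, the reference policy for the KL term is assumed to be the policy at the start of each iteration, and beta, the learning rate, and the convergence tolerance are arbitrary choices.

```python
import math
import random


def safe_log(x):
    return math.log(max(x, 1e-12))


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    # KL(p || q) between two categorical distributions.
    return sum(pi * (safe_log(pi) - safe_log(qi)) for pi, qi in zip(p, q) if pi > 0)


def train_iteration(logits, scores, rng, n_samples=64, beta=0.1, lr=0.5, steps=50):
    """One generate-then-train iteration for a single toy problem (hypothetical).

    logits: parameters of a categorical policy over K candidate reasoning steps.
    scores: per-step "relative correctness" scores, assumed already computed.
    """
    ref = softmax(logits)  # frozen reference policy for the KL penalty (assumption)
    # Generate: sample a batch of steps from the current policy.
    batch = rng.choices(range(len(logits)), weights=ref, k=n_samples)
    z = list(logits)
    for _ in range(steps):
        pi = softmax(z)
        d = kl(pi, ref)
        grad = [0.0] * len(z)
        for a in batch:
            # Gradient of -score * log pi(a) w.r.t. the logits: -score * (onehot(a) - pi)
            for j in range(len(z)):
                grad[j] -= scores[a] * ((1.0 if j == a else 0.0) - pi[j]) / n_samples
        for j in range(len(z)):
            # Gradient of beta * KL(pi || ref): beta * pi_j * (log pi_j - log ref_j - KL)
            grad[j] += beta * pi[j] * (safe_log(pi[j]) - safe_log(ref[j]) - d)
        z = [zj - lr * gj for zj, gj in zip(z, grad)]
    return z


def generate_then_train(scores, max_iters=10, tol=1e-3, seed=0):
    rng = random.Random(seed)
    logits = [0.0] * len(scores)  # uniform policy stands in for the pretrained model
    it = 0
    for it in range(1, max_iters + 1):
        new_logits = train_iteration(logits, scores, rng)
        shift = kl(softmax(new_logits), softmax(logits))
        logits = new_logits
        if shift < tol:  # policy barely changed: treat as converged
            break
    return softmax(logits), it
```

With scores [1.0, 0.2, -0.5], probability mass shifts toward the highest-scored step across iterations. Negative scores push a step's probability down, and the KL term limits how far each iteration moves away from its reference policy.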

Experimental Setup and Results

Experiments use Llama-3.1-8B-Instruct and deepseek-math-7b-instruct on the MATH and GSM8K datasets and compare the proposed approach against baselines including Zero-shot-CoT and Rejection Sampling Fine-Tuning (RFT). The proposed method consistently outperforms these baselines, with clear accuracy gains on both datasets, and the iterative training converges rapidly.

Transferability Evaluation

A further experiment evaluated cross-dataset transfer: models trained on one dataset showed improved performance on the other (e.g., models trained on GSM8K and tested on MATH). The results suggest that the models acquire generalizable reasoning skills, although the improvements are smaller than on the dataset used for training.

Conclusion and Limitations

Using MCTS to generate process-supervision data substantially improves LLM reasoning on the two mathematical datasets studied. The main limitation is that training converges rapidly, which limits the benefit of running many iterations. In addition, the models were fine-tuned with LoRA, which may have capped the overall improvement. Future work could explore ways to slow convergence and evaluate alternative parameter-efficient fine-tuning methods.

The paper offers a practical framework for improving LLM reasoning through process supervision and opens avenues for further research into scalable strategies for improving reasoning in AI systems.
