- [2025-02-14] ✨ Code is now available.
- [2025-02-12] 📢 Our work was covered by QbitAI (量子位) and AI Era (新智元).
- [2025-02-12] 🏅 Our paper ranked #1 on HuggingFace Daily Papers.
- [2025-02-11] 📄 Our paper is released on arXiv.
Clone the repository:
git clone https://github.com/RyanLiu112/compute-optimal-tts.git
cd compute-optimal-tts/src
Create a new conda environment and install the dependencies:
conda create -n tts python=3.10
conda activate tts
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
pip install "ray[default]==2.38.0"
pip install "fschat[model_worker,webui]"
pip install sympy==1.12
cd envs/MATH/latex2sympy
pip install -e .
Install tmux for serving policy models and PRMs:
sudo apt-get update
sudo apt-get install tmux
Note
Our mathematical expression evaluation code is based on Qwen2.5-Math. For a more powerful evaluator, please refer to this repository: Math-Verify.
Llama series (Instruct):
Qwen series (Instruct):
DeepSeek-R1-Distill series:
Supported PRMs:
- Math-Shepherd: Math-Shepherd-PRM-7B
- RLHFlow: RLHFlow-PRM-Mistral-8B, RLHFlow-PRM-Deepseek-8B
- Skywork: Skywork-PRM-1.5B, Skywork-PRM-7B
- Qwen2.5-Math: Qwen2.5-Math-PRM-7B, Qwen2.5-Math-PRM-72B
| Policy Model | PRM | GPUs |
|---|---|---|
| 0.5B-14B | 1.5B-8B | 1x A100 80GB |
| 32B | 1.5B-8B | 2x A100 80GB |
| 72B | 1.5B-8B | 3x A100 80GB |
| 0.5B-32B | 72B | 3x A100 80GB |
| 72B | 72B | 4x A100 80GB |
Set the environment variables (from the repository root):
cd src
export VALUE_MODEL_PATH=path/to/RM # set to dummy for CoT
export POLICY_MODEL_PATH=path/to/LM && export LOGDIR=path/to/logdir
export HOST_ADDR=0.0.0.0 && export CONTROLLER_PORT=10014 && export WORKER_BASE_PORT=10081
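Before starting the servers, it can help to fail early if any of the variables above is empty. The sketch below uses a hypothetical helper, `check_tts_env`, which is not part of the repository; it only checks that each variable is non-empty, not that the paths exist or the ports are free.

```bash
# Hypothetical helper (not part of the repository): report any of the
# variables set above that is empty, and return non-zero if one is missing.
check_tts_env() {
  missing=0
  for var in VALUE_MODEL_PATH POLICY_MODEL_PATH LOGDIR HOST_ADDR CONTROLLER_PORT WORKER_BASE_PORT; do
    # POSIX-compatible indirect lookup of the variable named by $var
    eval "val=\${$var:-}"
    if [ -z "$val" ]; then
      echo "missing: $var" >&2
      missing=1
    fi
  done
  return "$missing"
}
```

For example, `check_tts_env && echo "environment looks complete"` prints the message only when all six variables are set.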
Run the script that matches your model sizes and GPU count (see the table above):
# 1 GPU (0.5B-14B policy model + 1.5B-8B PRM)
bash scripts/serve_gpu1.sh $POLICY_MODEL_PATH $VALUE_MODEL_PATH $HOST_ADDR $CONTROLLER_PORT $WORKER_BASE_PORT
# 2 GPUs (32B policy model + 1.5B-8B PRM)
bash scripts/serve_gpu2.sh $POLICY_MODEL_PATH $VALUE_MODEL_PATH $HOST_ADDR $CONTROLLER_PORT $WORKER_BASE_PORT
# 3 GPUs (72B policy model + 1.5B-8B PRM)
bash scripts/serve_gpu3_1-2.sh $POLICY_MODEL_PATH $VALUE_MODEL_PATH $HOST_ADDR $CONTROLLER_PORT $WORKER_BASE_PORT
# 3 GPUs (0.5B-32B policy model + 72B PRM)
bash scripts/serve_gpu3_2-1.sh $POLICY_MODEL_PATH $VALUE_MODEL_PATH $HOST_ADDR $CONTROLLER_PORT $WORKER_BASE_PORT
# 4 GPUs (72B policy model + 72B PRM)
bash scripts/serve_gpu4.sh $POLICY_MODEL_PATH $VALUE_MODEL_PATH $HOST_ADDR $CONTROLLER_PORT $WORKER_BASE_PORT
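The GPU table above can be turned into a small lookup that prints which serve script to use. This is a sketch under assumptions: `pick_serve_script` is a hypothetical helper, sizes are given in billions of parameters, and sizes that fall between the table rows (for example a 20B policy or a 14B PRM) are mapped to the next larger row, which the table itself does not cover.

```bash
# Hypothetical helper: map (policy size, PRM size) in billions of parameters
# to the serve script from the GPU table. Sizes between table rows are rounded
# up to the next row (an assumption, not documented by the repository).
pick_serve_script() {
  awk -v p="$1" -v r="$2" 'BEGIN {
    p += 0; r += 0                     # force numeric comparison
    if (r <= 8) {                      # 1.5B-8B PRM rows
      if (p <= 14)      s = "serve_gpu1.sh"
      else if (p <= 32) s = "serve_gpu2.sh"
      else              s = "serve_gpu3_1-2.sh"
    } else {                           # 72B PRM rows
      if (p <= 32)      s = "serve_gpu3_2-1.sh"
      else              s = "serve_gpu4.sh"
    }
    print "scripts/" s
  }'
}
```

For example, `pick_serve_script 32 7` prints `scripts/serve_gpu2.sh`, matching the 2-GPU row, so it can be used as `bash "$(pick_serve_script 32 7)" $POLICY_MODEL_PATH $VALUE_MODEL_PATH $HOST_ADDR $CONTROLLER_PORT $WORKER_BASE_PORT`.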
We provide the following commands for the different test-time scaling (TTS) methods.
CoT:
cd src
bash scripts/run.sh --method cot --LM $POLICY_MODEL_PATH --RM dummy --width 1 --num_seq 1
Note
Configuring batch size for BoN and DVTS:
For instance, when running BoN on MATH-500, each of the 500 problems is processed 256 times (determined by num_q), i.e. 500 × 256 = 128,000 runs in total. To improve compute efficiency, we recommend distributing these runs across multiple GPUs by adjusting the batch size (bs). For example, set bs to 500 for 256 GPUs (128,000 / 256) or 16000 for 8 GPUs (128,000 / 8).
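The arithmetic in the note generalizes to: total runs = number of problems × num_q, and bs = total runs / number of GPUs. The sketch below uses a hypothetical helper, `compute_bs`, and rounds up when the division is not exact so that no run is left without a slot; that rounding is an assumption about how leftovers should be handled, not behavior documented by `run.sh`.

```bash
# Hypothetical helper: batch size = ceil(problems * num_q / gpus).
compute_bs() {
  problems=$1 num_q=$2 gpus=$3
  total=$((problems * num_q))
  # integer ceiling division so leftover runs still get a slot
  echo $(( (total + gpus - 1) / gpus ))
}
```

`compute_bs 500 256 8` prints `16000` and `compute_bs 500 256 256` prints `500`, matching the two examples in the note; the result can be passed as `--bs "$(compute_bs 500 256 8)"`.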
Best-of-N (BoN):
cd src
bash scripts/run.sh --method best_of_n --LM $POLICY_MODEL_PATH --RM $VALUE_MODEL_PATH --width 1 --num_seq 1 --num_q 256 --bs batch_size
Beam search:
cd src
bash scripts/run.sh --method beam_search --LM $POLICY_MODEL_PATH --RM $VALUE_MODEL_PATH --width 4 --num_seq 1
DVTS:
cd src
bash scripts/run.sh --method beam_search --LM $POLICY_MODEL_PATH --RM $VALUE_MODEL_PATH --width 4 --num_seq 1 --num_q 64 --bs batch_size
For BoN and DVTS, no average result is computed by default. To compute the average, aggregate the majority_vote values from all jsonl files after all problems have been processed num_q times.
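One way to do this aggregation from the shell is sketched below. It assumes, without confirmation from the repository, that each jsonl record contains one numeric `majority_vote` field and that a plain mean over all records is the desired average; `avg_majority_vote` is a hypothetical helper, and where the jsonl files live under `$LOGDIR` depends on your run.

```bash
# Hypothetical helper: mean of every numeric "majority_vote" value in the
# given jsonl files. Assumes one numeric field per record (an assumption).
avg_majority_vote() {
  grep -hoE '"majority_vote": *-?[0-9][0-9.eE+-]*' "$@" |
    awk -F: '{ sum += $2; n++ }
             END { if (n == 0) exit 1; printf "%.4f\n", sum / n }'
}
```

On a file containing the two records `{"majority_vote": 1.0}` and `{"majority_vote": 0.0}`, it prints `0.5000`; it returns non-zero if no `majority_vote` values are found.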
If you find this work helpful, please kindly cite our paper:
@article{liu2025can,
title = {Can 1B LLM Surpass 405B LLM? Rethinking Compute-Optimal Test-Time Scaling},
author = {Runze Liu and Junqi Gao and Jian Zhao and Kaiyan Zhang and Xiu Li and Biqing Qi and Wanli Ouyang and Bowen Zhou},
journal = {arXiv preprint arXiv:2502.06703},
year = {2025}
}
Our code is largely based on OpenR, an awesome LLM reasoning repository, and their work has been instrumental in our study. Our mathematical expression evaluation code is based on Qwen2.5-Math. We also want to thank the community for providing high-quality open-source PRMs, including Qwen2.5-Math, Skywork-o1, RLHFlow, and Math-Shepherd.