
Commit b11c696

youkaichao and LunrEclipse authored and committed
re-implement beam search on top of vllm core (vllm-project#8726)
Co-authored-by: Brendan Wong <bjwpokemon@gmail.com>
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent 5b81eee commit b11c696

File tree

4 files changed: +171 -9 lines changed

- benchmarks/benchmark_throughput.py
- tests/conftest.py
- tests/samplers/test_beam_search.py
- vllm/entrypoints/llm.py

benchmarks/benchmark_throughput.py

Lines changed: 20 additions & 4 deletions

@@ -90,6 +90,7 @@ def run_vllm(
     download_dir: Optional[str] = None,
     load_format: str = EngineArgs.load_format,
     disable_async_output_proc: bool = False,
+    use_new_beam_search_impl: bool = False,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -132,9 +133,23 @@
                 max_tokens=output_len,
             ))

-    start = time.perf_counter()
-    llm.generate(prompts, sampling_params, use_tqdm=True)
-    end = time.perf_counter()
+    if not use_new_beam_search_impl:
+        start = time.perf_counter()
+        llm.generate(prompts, sampling_params, use_tqdm=True)
+        end = time.perf_counter()
+    else:
+        assert use_beam_search
+        prompts = [prompt for prompt, _, _ in requests]
+        # output_len should be the same for all requests.
+        output_len = requests[0][2]
+        for prompt, input_len, _output_len in requests:
+            assert _output_len == output_len
+        start = time.perf_counter()
+        llm.beam_search(prompts,
+                        beam_width=n,
+                        max_tokens=output_len,
+                        ignore_eos=True)
+        end = time.perf_counter()
     return end - start


@@ -336,7 +351,7 @@ def main(args: argparse.Namespace):
             run_args.append(args.disable_frontend_multiprocessing)
             elapsed_time = uvloop.run(run_vllm_async(*run_args))
         else:
-            elapsed_time = run_vllm(*run_args)
+            elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -396,6 +411,7 @@ def main(args: argparse.Namespace):
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
+   parser.add_argument("--use-new-beam-search-impl", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
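For reference, a minimal sketch of what the new benchmark path measures when `--use-new-beam-search-impl` is set. The model name, prompts, and sizes below are placeholders, not values from this commit:

```python
import time

from vllm import LLM

# Placeholder model and prompts; the real benchmark builds these from its
# dataset and CLI arguments.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
prompts = ["The future of AI is"] * 8

start = time.perf_counter()
# Same call the benchmark makes: beam_width comes from --n, max_tokens from
# the fixed per-request output length, and EOS is ignored so every request
# generates the same number of tokens.
llm.beam_search(prompts, beam_width=4, max_tokens=64, ignore_eos=True)
elapsed = time.perf_counter() - start
print(f"{len(prompts) / elapsed:.2f} requests/s")
```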

tests/conftest.py

Lines changed: 14 additions & 0 deletions

@@ -798,6 +798,20 @@ def generate_beam_search(
         outputs = self.generate(prompts, beam_search_params)
         return outputs

+    def generate_beam_search_new(
+        self,
+        prompts: Union[List[str], List[List[int]]],
+        beam_width: int,
+        max_tokens: int,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
+        outputs = self.model.beam_search(prompts, beam_width, max_tokens)
+        returned_outputs = []
+        for output in outputs:
+            token_ids = [x.tokens for x in output.sequences]
+            texts = [x.text for x in output.sequences]
+            returned_outputs.append((token_ids, texts))
+        return returned_outputs
+
     def encode(self, prompts: List[str]) -> List[List[float]]:
         req_outputs = self.model.encode(prompts)
         outputs = []
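A sketch of how a test using the `vllm_runner` fixture might consume the new helper; the prompt and sizes are illustrative, not taken from the repository:

```python
# Each returned element pairs per-beam token ids with per-beam decoded texts.
with vllm_runner("TinyLlama/TinyLlama-1.1B-Chat-v1.0") as vllm_model:
    outputs = vllm_model.generate_beam_search_new(["Hello, my name is"],
                                                  beam_width=4,
                                                  max_tokens=16)
for token_ids, texts in outputs:
    assert len(token_ids) == len(texts)  # one entry per beam
    print(texts[0])  # highest-scoring beam for this prompt
```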

tests/samplers/test_beam_search.py

Lines changed: 3 additions & 3 deletions

@@ -9,7 +9,7 @@
 # 1. Increase max_tokens to 256.
 # 2. Increase beam_width to 8.
 # 3. Use the model "huggyllama/llama-7b".
-MAX_TOKENS = [128]
+MAX_TOKENS = [64]
 BEAM_WIDTHS = [4]
 MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]

@@ -33,8 +33,8 @@ def test_beam_search_single_input(
                                                    max_tokens)

     with vllm_runner(model, dtype=dtype) as vllm_model:
-        vllm_outputs = vllm_model.generate_beam_search(example_prompts,
-                                                       beam_width, max_tokens)
+        vllm_outputs = vllm_model.generate_beam_search_new(
+            example_prompts, beam_width, max_tokens)

     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_texts = hf_outputs[i]
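The comparison that presumably follows the context lines above would look roughly like this sketch (not copied from the test file): corresponding beams from the HF reference and the new vLLM implementation are checked beam by beam.

```python
for i in range(len(example_prompts)):
    hf_output_ids, hf_output_texts = hf_outputs[i]
    vllm_output_ids, vllm_output_texts = vllm_outputs[i]
    for j in range(len(hf_output_ids)):
        # Token-level comparison; texts are printed to make failures readable.
        assert hf_output_ids[j] == vllm_output_ids[j], (
            f"prompt {i}, beam {j}: HF and vLLM beam search disagree\n"
            f"HF:   {hf_output_texts[j]!r}\nvLLM: {vllm_output_texts[j]!r}")
```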

vllm/entrypoints/llm.py

Lines changed: 134 additions & 2 deletions

@@ -1,6 +1,8 @@
+import itertools
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
-                    overload)
+from dataclasses import dataclass
+from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
+                    Union, cast, overload)

 from tqdm import tqdm

@@ -30,6 +32,37 @@
 logger = init_logger(__name__)


+@dataclass
+class BeamSearchSequence:
+    """A sequence for beam search.
+    It keeps track of the tokens and the log probability of the sequence.
+    The text field is optional and will only be filled when the sequence is
+    about to be returned to the user.
+    """
+    # The tokens includes the prompt.
+    tokens: List[int]
+    cum_logprob: float = 0.0
+    text: Optional[str] = None
+
+
+@dataclass
+class BeamSearchOutput:
+    """The output of beam search.
+    It contains the list of the best beam search sequences.
+    The length of the list is equal to the beam width.
+    """
+    sequences: List[BeamSearchSequence]
+
+
+class BeamSearchInstance:
+
+    def __init__(self, prompt_tokens: List[int]):
+        self.beams: List[BeamSearchSequence] = [
+            BeamSearchSequence(tokens=prompt_tokens)
+        ]
+        self.completed: List[BeamSearchSequence] = []
+
+
 class LLM:
     """An LLM for generating texts from given prompts and sampling parameters.

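To illustrate how the three new classes fit together (a sketch, not code from the diff): an instance starts with one beam holding only the prompt tokens, each expansion step appends a token and accumulates its log-probability, and the ranked beams are wrapped in a `BeamSearchOutput`.

```python
instance = BeamSearchInstance(prompt_tokens=[1, 2, 3])  # placeholder tokens
root = instance.beams[0]                                # prompt-only beam

# Extend the beam by one (hypothetical) token whose logprob is -0.7.
child = BeamSearchSequence(tokens=root.tokens + [42],
                           cum_logprob=root.cum_logprob + (-0.7))
instance.beams = [child]

# When decoding stops, live and completed beams are ranked by cum_logprob.
instance.completed.extend(instance.beams)
best = sorted(instance.completed, key=lambda b: b.cum_logprob, reverse=True)
result = BeamSearchOutput(sequences=best)
```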
@@ -354,6 +387,105 @@ def generate(
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)

+    def beam_search(
+        self,
+        prompts: List[Union[str, List[int]]],
+        beam_width: int,
+        max_tokens: int,
+        ignore_eos: bool = False,
+    ) -> List[BeamSearchOutput]:
+        """
+        Generate sequences using beam search.
+
+        Args:
+            prompts: A list of prompts. Each prompt can be a string or a list
+                of token IDs.
+            beam_width: The number of beams to keep at each step.
+            max_tokens: The max number of tokens to generate for each prompt.
+
+        TODO: how does beam search work together with length penalty, frequency
+        penalty, and stopping criteria, etc.?
+        """
+
+        tokenizer = self.get_tokenizer()
+        # generate 2 * beam_width candidates at each step
+        # following the huggingface transformers implementation
+        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=0.0)
+        instances: List[BeamSearchInstance] = []
+
+        for prompt in prompts:
+            prompt_tokens = prompt if isinstance(
+                prompt, list) else tokenizer.encode(prompt)
+            instances.append(BeamSearchInstance(prompt_tokens))
+
+        for _ in range(max_tokens):
+            all_beams: List[BeamSearchSequence] = list(
+                sum((instance.beams for instance in instances), []))
+            pos = [0] + list(
+                itertools.accumulate(
+                    len(instance.beams) for instance in instances))
+            instance_start_and_end: List[Tuple[int, int]] = list(
+                zip(pos[:-1], pos[1:]))
+
+            if len(all_beams) == 0:
+                break
+
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            # only runs for one step
+            # we don't need to use tqdm here
+            output = self.generate(prompts_batch,
+                                   sampling_params=beam_search_params,
+                                   use_tqdm=False)
+
+            for (start, end), instance in zip(instance_start_and_end,
+                                              instances):
+                instance_new_beams = []
+                for i in range(start, end):
+                    current_beam = all_beams[i]
+                    result = output[i]
+
+                    if result.outputs[0].logprobs is not None:
+                        # if `result.outputs[0].logprobs` is None, it means
+                        # the sequence is completed because of the max-model-len
+                        # or abortion. we don't need to add it to the new beams.
+                        logprobs = result.outputs[0].logprobs[0]
+                        for token_id, logprob_obj in logprobs.items():
+                            new_beam = BeamSearchSequence(
+                                tokens=current_beam.tokens + [token_id],
+                                cum_logprob=current_beam.cum_logprob +
+                                logprob_obj.logprob)
+
+                            if token_id == tokenizer.eos_token_id and \
+                                not ignore_eos:
+                                instance.completed.append(new_beam)
+                            else:
+                                instance_new_beams.append(new_beam)
+                sorted_beams = sorted(instance_new_beams,
+                                      key=lambda x: x.cum_logprob,
+                                      reverse=True)
+                instance.beams = sorted_beams[:beam_width]
+
+        outputs = []
+        for instance in instances:
+            instance.completed.extend(instance.beams)
+            sorted_completed = sorted(instance.completed,
+                                      key=lambda x: x.cum_logprob,
+                                      reverse=True)
+            best_beams = sorted_completed[:beam_width]
+
+            for beam in best_beams:
+                beam.text = tokenizer.decode(beam.tokens)
+            outputs.append(BeamSearchOutput(sequences=best_beams))
+
+        return outputs
+
     def chat(
         self,
         messages: List[ChatCompletionMessageParam],
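A usage sketch for the new `LLM.beam_search` API added above; the model name and prompt are placeholders:

```python
from vllm import LLM

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
outputs = llm.beam_search(["The capital of France is"],
                          beam_width=4,
                          max_tokens=32)

# One BeamSearchOutput per prompt; its sequences are ranked best-first by
# cumulative log-probability, and each token list includes the prompt.
for seq in outputs[0].sequences:
    print(f"{seq.cum_logprob:8.2f}  {seq.text}")
```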
