
Commit fedcd81

LunrEclipse authored and garg-amit committed
[Frontend] API support for beam search for MQLLMEngine (vllm-project#9117)
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent ca840fd commit fedcd81


8 files changed: +215 −106 lines changed


tests/entrypoints/openai/test_completion.py

Lines changed: 19 additions & 24 deletions
@@ -495,30 +495,25 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
     assert len(batch.choices) == 2
     assert batch.choices[0].text == batch.choices[1].text
 
-    try:
-        # test n = 2
-        batch = await client.completions.create(
-            model=model_name,
-            prompt=prompts,
-            n=2,
-            max_tokens=5,
-            temperature=0.0,
-            extra_body=dict(
-                # NOTE: this has to be true for n > 1 in vLLM, but
-                # not necessary for official client.
-                use_beam_search=True),
-        )
-        assert len(batch.choices) == 4
-        assert batch.choices[0].text != batch.choices[
-            1].text, "beam search should be different"
-        assert batch.choices[0].text == batch.choices[
-            2].text, "two copies of the same prompt should be the same"
-        assert batch.choices[1].text == batch.choices[
-            3].text, "two copies of the same prompt should be the same"
-    except BadRequestError as e:
-        # the only allowed exception is when beam search is not supported
-        # in the default mqllmengine
-        assert "--disable-frontend-multiprocessing" in str(e)
+    # test n = 2
+    batch = await client.completions.create(
+        model=model_name,
+        prompt=prompts,
+        n=2,
+        max_tokens=5,
+        temperature=0.0,
+        extra_body=dict(
+            # NOTE: this has to be true for n > 1 in vLLM, but
+            # not necessary for official client.
+            use_beam_search=True),
+    )
+    assert len(batch.choices) == 4
+    assert batch.choices[0].text != batch.choices[
+        1].text, "beam search should be different"
+    assert batch.choices[0].text == batch.choices[
+        2].text, "two copies of the same prompt should be the same"
+    assert batch.choices[1].text == batch.choices[
+        3].text, "two copies of the same prompt should be the same"
 
     # test streaming
     batch = await client.completions.create(
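The removed try/except was the old escape hatch for the default multiprocessing frontend; with MQLLMEngineClient gaining beam_search, the request is now simply expected to succeed. For reference, a minimal standalone version of the request this test exercises might look like the sketch below; the base URL and model name are placeholders for whatever OpenAI-compatible vLLM server you run, not values taken from the test itself.

import asyncio

import openai

# Placeholder endpoint and model; point these at your own vLLM server.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


async def demo():
    batch = await client.completions.create(
        model="your-model-name",
        prompt=["Hello, my name is", "Hello, my name is"],
        n=2,
        max_tokens=5,
        temperature=0.0,
        # vLLM-specific: beam search must be requested explicitly for n > 1.
        extra_body=dict(use_beam_search=True),
    )
    # Two prompts x two beams per prompt -> four choices.
    assert len(batch.choices) == 4


asyncio.run(demo())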

vllm/beam_search.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+
+@dataclass
+class BeamSearchSequence:
+    """A sequence for beam search.
+    It keeps track of the tokens and the log probability of the sequence.
+    The text field is optional and will only be filled when the sequence is
+    about to be returned to the user.
+    """
+    # The tokens includes the prompt.
+    tokens: List[int]
+    cum_logprob: float = 0.0
+    text: Optional[str] = None
+
+
+@dataclass
+class BeamSearchOutput:
+    """The output of beam search.
+    It contains the list of the best beam search sequences.
+    The length of the list is equal to the beam width.
+    """
+    sequences: List[BeamSearchSequence]
+
+
+class BeamSearchInstance:
+
+    def __init__(self, prompt_tokens: List[int]):
+        self.beams: List[BeamSearchSequence] = [
+            BeamSearchSequence(tokens=prompt_tokens)
+        ]
+        self.completed: List[BeamSearchSequence] = []
+
+
+def get_beam_search_score(
+    tokens: List[int],
+    cumulative_logprob: float,
+    eos_token_id: int,
+    length_penalty: float = 1.0,
+) -> float:
+    """Calculate the beam search score with length penalty.
+
+    Adapted from
+
+    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
+    """
+    seq_len = len(tokens)
+    if tokens[-1] == eos_token_id:
+        seq_len -= 1
+
+    return cumulative_logprob / (seq_len**length_penalty)
+
+
+def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
+
+    def sort_beams_key(x: BeamSearchSequence) -> float:
+        return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id,
+                                     length_penalty)
+
+    return sort_beams_key
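A quick sketch of how these helpers fit together. The token ids, logprobs, and eos_token_id below are made-up numbers for illustration, not values from any real tokenizer:

from vllm.beam_search import (BeamSearchSequence,
                              create_sort_beams_key_function,
                              get_beam_search_score)

eos_token_id = 2  # hypothetical EOS id
short = BeamSearchSequence(tokens=[5, 7, 2], cum_logprob=-1.2)   # ends with EOS
long = BeamSearchSequence(tokens=[5, 7, 9, 11, 13], cum_logprob=-1.8)

# With length_penalty=1.0 the cumulative logprob is divided by the
# EOS-trimmed sequence length, so longer beams are not penalized simply
# for accumulating more (negative) logprobs.
print(get_beam_search_score(short.tokens, short.cum_logprob, eos_token_id))  # -1.2 / 2 = -0.6
print(get_beam_search_score(long.tokens, long.cum_logprob, eos_token_id))    # -1.8 / 5 = -0.36

key = create_sort_beams_key_function(eos_token_id, length_penalty=1.0)
ranked = sorted([short, long], key=key, reverse=True)
print([s.tokens for s in ranked])  # the longer beam ranks first here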

vllm/engine/async_llm_engine.py

Lines changed: 5 additions & 7 deletions
@@ -7,14 +7,14 @@
 from weakref import ReferenceType
 
 import vllm.envs as envs
+from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
-from vllm.entrypoints.llm import BeamSearchSequence
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
@@ -33,7 +33,7 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
-                        get_beam_search_score, random_uuid, weak_bind)
+                        random_uuid, weak_bind)
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1052,16 +1052,14 @@ async def beam_search(
         temperature = params.temperature
         length_penalty = params.length_penalty
 
-        def sort_beams_key(x: BeamSearchSequence) -> float:
-            return get_beam_search_score(x.tokens, x.cum_logprob,
-                                         tokenizer.eos_token_id,
-                                         length_penalty)
-
         tokenizer = await self.get_tokenizer()
         tokenizedPrompt = prompt if isinstance(
             prompt, list) else tokenizer.encode(prompt)
         tokenizedLength = len(tokenizedPrompt)
 
+        sort_beams_key = create_sort_beams_key_function(
+            tokenizer.eos_token_id, length_penalty)
+
         beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                             max_tokens=1,
                                             temperature=temperature)
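The inline sort_beams_key closure deleted here and the key returned by the new shared factory behave the same; a small self-contained sketch of that equivalence, using a SimpleNamespace as a stand-in tokenizer (eos_token_id=2 and the beam values are invented for the example):

from types import SimpleNamespace

from vllm.beam_search import (BeamSearchSequence,
                              create_sort_beams_key_function,
                              get_beam_search_score)

tokenizer = SimpleNamespace(eos_token_id=2)  # stand-in for a real tokenizer
length_penalty = 1.0


# Old style: a closure defined locally inside AsyncLLMEngine.beam_search.
def sort_beams_key(x: BeamSearchSequence) -> float:
    return get_beam_search_score(x.tokens, x.cum_logprob,
                                 tokenizer.eos_token_id, length_penalty)


# New style: the shared factory produces the same key function, so
# MQLLMEngineClient.beam_search can reuse it instead of duplicating the logic.
shared_key = create_sort_beams_key_function(tokenizer.eos_token_id,
                                            length_penalty)

beam = BeamSearchSequence(tokens=[4, 5, 6], cum_logprob=-0.9)
assert sort_beams_key(beam) == shared_key(beam)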

vllm/engine/multiprocessing/client.py

Lines changed: 107 additions & 6 deletions
@@ -2,8 +2,8 @@
 import copy
 import pickle
 from contextlib import contextmanager, suppress
-from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
-                    Union, overload)
+from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
+                    Optional, Union, overload)
 
 import cloudpickle
 import zmq
@@ -12,6 +12,7 @@
 from zmq.asyncio import Socket
 
 from vllm import PoolingParams
+from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, EngineConfig, ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 # yapf conflicts with isort for this block
@@ -27,14 +28,16 @@
                                              RPCUProfileRequest)
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
-from vllm.inputs import PromptType
+from vllm.inputs import PromptType, TokensPrompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
+                          RequestOutput)
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.utils import deprecate_kwargs
+from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
+                        random_uuid)
 
 logger = init_logger(__name__)
 
@@ -441,6 +444,104 @@ def generate(
                              lora_request, trace_headers,
                              prompt_adapter_request, priority)
 
+    async def beam_search(
+            self,
+            prompt: Union[PromptType, List[int]],
+            request_id: str,
+            params: BeamSearchParams,
+    ) -> AsyncGenerator[RequestOutput, None]:
+
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        ignore_eos = params.ignore_eos
+        temperature = params.temperature
+        length_penalty = params.length_penalty
+
+        tokenizer = await self.get_tokenizer(lora_request=None)
+        tokenizedPrompt = prompt if isinstance(
+            prompt, list) else tokenizer.encode(prompt)
+        tokenizedLength = len(tokenizedPrompt)
+
+        sort_beams_key = create_sort_beams_key_function(
+            tokenizer.eos_token_id, length_penalty)
+
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=temperature)
+        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        completed = []
+
+        for _ in range(max_tokens):
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            tasks = []
+
+            request_id = f"beam_search-{random_uuid()}"
+            for i, individual_prompt in enumerate(prompts_batch):
+                request_id_item = f"{request_id}-{i}"
+                task = asyncio.create_task(
+                    collect_from_async_generator(
+                        self.generate(individual_prompt, beam_search_params,
                                      request_id_item)))
+                tasks.append(task)
+
+            output = await asyncio.gather(*tasks)
+
+            output = [x[0] for x in output]
+
+            logger.info(output)
+
+            new_beams = []
+            for i, current_beam in enumerate(all_beams):
+                result = output[i]
+
+                if result.outputs[0].logprobs is not None:
+                    logprobs = result.outputs[0].logprobs[0]
+                    for token_id, logprob_obj in logprobs.items():
+                        new_beam = BeamSearchSequence(
+                            tokens=current_beam.tokens + [token_id],
+                            cum_logprob=current_beam.cum_logprob +
+                            logprob_obj.logprob)
+
+                        if token_id == tokenizer.eos_token_id and \
+                            not ignore_eos:
+                            completed.append(new_beam)
+                        else:
+                            new_beams.append(new_beam)
+
+            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
+            all_beams = sorted_beams[:beam_width]
+
+        completed.extend(all_beams)
+        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
+        best_beams = sorted_completed[:beam_width]
+
+        for beam in best_beams:
+            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+
+        beam_search_output = RequestOutput(
+            request_id=request_id,
+            prompt=prompt,
+            outputs=[
+                CompletionOutput(
+                    text=beam.text,
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens,
+                    index=i,
+                    logprobs=beam.cum_logprob,
+                ) for (i, beam) in enumerate(best_beams)
+            ],
+            finished=True,
+            prompt_token_ids=tokenizedPrompt,
+            prompt_logprobs=None)
+
+        logger.info(beam_search_output)
+
+        yield beam_search_output
+
     @overload  # DEPRECATED
     def encode(
             self,
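The new MQLLMEngineClient.beam_search mirrors the AsyncLLMEngine implementation: each iteration asks the engine for exactly one token per live beam (max_tokens=1, logprobs=2 * beam_width), expands every beam by each candidate token, moves beams that emit EOS to the completed list, and keeps the top beam_width survivors by length-penalized score. A self-contained sketch of that expansion step, with hand-written candidate logprobs standing in for real engine output (token ids and eos_token_id are invented for the example):

from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function

eos_token_id = 2  # assumed EOS id for the sketch
beam_width = 2
sort_key = create_sort_beams_key_function(eos_token_id, length_penalty=1.0)

all_beams = [BeamSearchSequence(tokens=[10, 11], cum_logprob=0.0)]
completed = []

# Pretend the engine returned these top-(2 * beam_width) candidates for the beam.
candidate_logprobs = {42: -0.1, 2: -0.5, 17: -1.3, 99: -2.0}

new_beams = []
for beam in all_beams:
    for token_id, logprob in candidate_logprobs.items():
        new_beam = BeamSearchSequence(tokens=beam.tokens + [token_id],
                                      cum_logprob=beam.cum_logprob + logprob)
        # Beams that emit EOS are finished; the rest compete in the next round.
        if token_id == eos_token_id:
            completed.append(new_beam)
        else:
            new_beams.append(new_beam)

all_beams = sorted(new_beams, key=sort_key, reverse=True)[:beam_width]
print([b.tokens for b in all_beams])   # [[10, 11, 42], [10, 11, 17]]
print([b.tokens for b in completed])   # [[10, 11, 2]]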

vllm/entrypoints/llm.py

Lines changed: 3 additions & 34 deletions
@@ -1,12 +1,13 @@
 import itertools
 import warnings
 from contextlib import contextmanager
-from dataclasses import dataclass
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
                     Union, cast, overload)
 
 from tqdm import tqdm
 
+from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
+                              BeamSearchSequence, get_beam_search_score)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
@@ -28,43 +29,11 @@
                                          get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
-                        is_list_of)
+from vllm.utils import Counter, deprecate_kwargs, is_list_of
 
 logger = init_logger(__name__)
 
 
-@dataclass
-class BeamSearchSequence:
-    """A sequence for beam search.
-    It keeps track of the tokens and the log probability of the sequence.
-    The text field is optional and will only be filled when the sequence is
-    about to be returned to the user.
-    """
-    # The tokens includes the prompt.
-    tokens: List[int]
-    cum_logprob: float = 0.0
-    text: Optional[str] = None
-
-
-@dataclass
-class BeamSearchOutput:
-    """The output of beam search.
-    It contains the list of the best beam search sequences.
-    The length of the list is equal to the beam width.
-    """
-    sequences: List[BeamSearchSequence]
-
-
-class BeamSearchInstance:
-
-    def __init__(self, prompt_tokens: List[int]):
-        self.beams: List[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens)
-        ]
-        self.completed: List[BeamSearchSequence] = []
-
-
 class LLM:
     """An LLM for generating texts from given prompts and sampling parameters.
 
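The three beam-search classes now live in vllm.beam_search rather than being defined here; llm.py only re-imports them. A sketch of the import change for downstream code (the old path still resolves only because of that re-import, so the new module is the canonical location):

# Old location (pre-commit): defined directly in vllm/entrypoints/llm.py.
# from vllm.entrypoints.llm import BeamSearchSequence

# New canonical location after this commit:
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence)

instance = BeamSearchInstance(prompt_tokens=[1, 2, 3])
print(len(instance.beams), len(instance.completed))  # 1 0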

vllm/entrypoints/openai/serving_chat.py

Lines changed: 10 additions & 8 deletions
@@ -10,6 +10,7 @@
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          apply_hf_chat_template,
@@ -236,15 +237,16 @@ async def create_chat_completion(
                 log_tracing_disabled_warning()
 
             if isinstance(sampling_params, BeamSearchParams):
-                if not isinstance(self.engine_client, AsyncLLMEngine):
-                    raise ValueError(
-                        "Beam search in the API server is only supported with"
-                        " AsyncLLMEngine. please add "
-                        "`--disable-frontend-multiprocessing` to "
-                        "use beam search.")
+                assert isinstance(self.engine_client,
+                                  (AsyncLLMEngine,
+                                   MQLLMEngineClient)), \
+                    "Beam search is only supported with" \
+                    "AsyncLLMEngine and MQLLMEngineClient."
                 result_generator = self.engine_client.beam_search(
-                    engine_inputs['prompt_token_ids'], request_id,
-                    sampling_params)
+                    engine_inputs['prompt_token_ids'],
+                    request_id,
+                    sampling_params,
+                )
             else:
                 result_generator = self.engine_client.generate(
                     engine_inputs,
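With MQLLMEngineClient now accepted here, beam search works through the chat endpoint on a default server start, without the --disable-frontend-multiprocessing workaround the old error message suggested. A hedged example request, assuming a local server and that the chat endpoint accepts the same use_beam_search extra field as the completions endpoint (server URL and model name are placeholders):

import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


async def main():
    resp = await client.chat.completions.create(
        model="your-model-name",
        messages=[{"role": "user", "content": "Name three colors."}],
        n=2,
        max_tokens=16,
        temperature=0.0,
        # vLLM-specific extra field to route the request through beam search.
        extra_body=dict(use_beam_search=True),
    )
    for choice in resp.choices:
        print(choice.index, choice.message.content)


asyncio.run(main())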
