diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py
index 61da5513cb1..cc72a49ebbb 100644
--- a/tests/entrypoints/openai/test_completion.py
+++ b/tests/entrypoints/openai/test_completion.py
@@ -495,30 +495,25 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
     assert len(batch.choices) == 2
     assert batch.choices[0].text == batch.choices[1].text
 
-    try:
-        # test n = 2
-        batch = await client.completions.create(
-            model=model_name,
-            prompt=prompts,
-            n=2,
-            max_tokens=5,
-            temperature=0.0,
-            extra_body=dict(
-                # NOTE: this has to be true for n > 1 in vLLM, but
-                # not necessary for official client.
-                use_beam_search=True),
-        )
-        assert len(batch.choices) == 4
-        assert batch.choices[0].text != batch.choices[
-            1].text, "beam search should be different"
-        assert batch.choices[0].text == batch.choices[
-            2].text, "two copies of the same prompt should be the same"
-        assert batch.choices[1].text == batch.choices[
-            3].text, "two copies of the same prompt should be the same"
-    except BadRequestError as e:
-        # the only allowed exception is when beam search is not supported
-        # in the default mqllmengine
-        assert "--disable-frontend-multiprocessing" in str(e)
+    # test n = 2
+    batch = await client.completions.create(
+        model=model_name,
+        prompt=prompts,
+        n=2,
+        max_tokens=5,
+        temperature=0.0,
+        extra_body=dict(
+            # NOTE: this has to be true for n > 1 in vLLM, but
+            # not necessary for official client.
+            use_beam_search=True),
+    )
+    assert len(batch.choices) == 4
+    assert batch.choices[0].text != batch.choices[
+        1].text, "beam search should be different"
+    assert batch.choices[0].text == batch.choices[
+        2].text, "two copies of the same prompt should be the same"
+    assert batch.choices[1].text == batch.choices[
+        3].text, "two copies of the same prompt should be the same"
 
     # test streaming
     batch = await client.completions.create(
diff --git a/vllm/beam_search.py b/vllm/beam_search.py
new file mode 100644
index 00000000000..04624b8b944
--- /dev/null
+++ b/vllm/beam_search.py
@@ -0,0 +1,61 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+
+@dataclass
+class BeamSearchSequence:
+    """A sequence for beam search.
+    It keeps track of the tokens and the log probability of the sequence.
+    The text field is optional and will only be filled when the sequence is
+    about to be returned to the user.
+    """
+    # The tokens includes the prompt.
+    tokens: List[int]
+    cum_logprob: float = 0.0
+    text: Optional[str] = None
+
+
+@dataclass
+class BeamSearchOutput:
+    """The output of beam search.
+    It contains the list of the best beam search sequences.
+    The length of the list is equal to the beam width.
+    """
+    sequences: List[BeamSearchSequence]
+
+
+class BeamSearchInstance:
+
+    def __init__(self, prompt_tokens: List[int]):
+        self.beams: List[BeamSearchSequence] = [
+            BeamSearchSequence(tokens=prompt_tokens)
+        ]
+        self.completed: List[BeamSearchSequence] = []
+
+
+def get_beam_search_score(
+    tokens: List[int],
+    cumulative_logprob: float,
+    eos_token_id: int,
+    length_penalty: float = 1.0,
+) -> float:
+    """Calculate the beam search score with length penalty.
+
+    Adapted from
+
+    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
+    """
+    seq_len = len(tokens)
+    if tokens[-1] == eos_token_id:
+        seq_len -= 1
+
+    return cumulative_logprob / (seq_len**length_penalty)
+
+
+def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
+
+    def sort_beams_key(x: BeamSearchSequence) -> float:
+        return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id,
+                                     length_penalty)
+
+    return sort_beams_key
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 50269493d64..30e1a09981c 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -7,6 +7,7 @@
 from weakref import ReferenceType
 
 import vllm.envs as envs
+from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.core.scheduler import SchedulerOutputs
@@ -14,7 +15,6 @@
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
-from vllm.entrypoints.llm import BeamSearchSequence
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
@@ -33,7 +33,7 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
-                        get_beam_search_score, random_uuid, weak_bind)
+                        random_uuid, weak_bind)
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1052,16 +1052,14 @@ async def beam_search(
         temperature = params.temperature
         length_penalty = params.length_penalty
 
-        def sort_beams_key(x: BeamSearchSequence) -> float:
-            return get_beam_search_score(x.tokens, x.cum_logprob,
-                                         tokenizer.eos_token_id,
-                                         length_penalty)
-
         tokenizer = await self.get_tokenizer()
         tokenizedPrompt = prompt if isinstance(
             prompt, list) else tokenizer.encode(prompt)
         tokenizedLength = len(tokenizedPrompt)
 
+        sort_beams_key = create_sort_beams_key_function(
+            tokenizer.eos_token_id, length_penalty)
+
         beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                             max_tokens=1,
                                             temperature=temperature)
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index b0d061dbab4..820f678abef 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -2,8 +2,8 @@
 import copy
 import pickle
 from contextlib import contextmanager, suppress
-from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
-                    Union, overload)
+from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
+                    Optional, Union, overload)
 
 import cloudpickle
 import zmq
@@ -12,6 +12,7 @@
 from zmq.asyncio import Socket
 
 from vllm import PoolingParams
+from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, EngineConfig, ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 # yapf conflicts with isort for this block
@@ -27,14 +28,16 @@
                                          RPCUProfileRequest)
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
-from vllm.inputs import PromptType
+from vllm.inputs import PromptType, TokensPrompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
+                          RequestOutput)
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.utils import deprecate_kwargs
+from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
+                        random_uuid)
 
 logger = init_logger(__name__)
 
@@ -441,6 +444,104 @@ def generate(
             lora_request, trace_headers, prompt_adapter_request, priority)
 
+    async def beam_search(
+        self,
+        prompt: Union[PromptType, List[int]],
+        request_id: str,
+        params: BeamSearchParams,
+    ) -> AsyncGenerator[RequestOutput, None]:
+
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        ignore_eos = params.ignore_eos
+        temperature = params.temperature
+        length_penalty = params.length_penalty
+
+        tokenizer = await self.get_tokenizer(lora_request=None)
+        tokenizedPrompt = prompt if isinstance(
+            prompt, list) else tokenizer.encode(prompt)
+        tokenizedLength = len(tokenizedPrompt)
+
+        sort_beams_key = create_sort_beams_key_function(
+            tokenizer.eos_token_id, length_penalty)
+
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=temperature)
+        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        completed = []
+
+        for _ in range(max_tokens):
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            tasks = []
+
+            request_id = f"beam_search-{random_uuid()}"
+            for i, individual_prompt in enumerate(prompts_batch):
+                request_id_item = f"{request_id}-{i}"
+                task = asyncio.create_task(
+                    collect_from_async_generator(
+                        self.generate(individual_prompt, beam_search_params,
+                                      request_id_item)))
+                tasks.append(task)
+
+            output = await asyncio.gather(*tasks)
+
+            output = [x[0] for x in output]
+
+            logger.info(output)
+
+            new_beams = []
+            for i, current_beam in enumerate(all_beams):
+                result = output[i]
+
+                if result.outputs[0].logprobs is not None:
+                    logprobs = result.outputs[0].logprobs[0]
+                    for token_id, logprob_obj in logprobs.items():
+                        new_beam = BeamSearchSequence(
+                            tokens=current_beam.tokens + [token_id],
+                            cum_logprob=current_beam.cum_logprob +
+                            logprob_obj.logprob)
+
+                        if token_id == tokenizer.eos_token_id and \
+                            not ignore_eos:
+                            completed.append(new_beam)
+                        else:
+                            new_beams.append(new_beam)
+
+            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
+            all_beams = sorted_beams[:beam_width]
+
+        completed.extend(all_beams)
+        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
+        best_beams = sorted_completed[:beam_width]
+
+        for beam in best_beams:
+            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+
+        beam_search_output = RequestOutput(
+            request_id=request_id,
+            prompt=prompt,
+            outputs=[
+                CompletionOutput(
+                    text=beam.text,
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens,
+                    index=i,
+                    logprobs=beam.cum_logprob,
+                ) for (i, beam) in enumerate(best_beams)
+            ],
+            finished=True,
+            prompt_token_ids=tokenizedPrompt,
+            prompt_logprobs=None)
+
+        logger.info(beam_search_output)
+
+        yield beam_search_output
+
     @overload  # DEPRECATED
     def encode(
         self,
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 439f3769f9f..b0a8a66ec13 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1,12 +1,13 @@
 import itertools
 import warnings
 from contextlib import contextmanager
-from dataclasses import dataclass
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
                     Union, cast, overload)
 
 from tqdm import tqdm
 
+from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
+                              BeamSearchSequence, get_beam_search_score)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
@@ -28,43 +29,11 @@
                                          get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
-                        is_list_of)
+from vllm.utils import Counter, deprecate_kwargs, is_list_of
 
 logger = init_logger(__name__)
 
 
-@dataclass
-class BeamSearchSequence:
-    """A sequence for beam search.
-    It keeps track of the tokens and the log probability of the sequence.
-    The text field is optional and will only be filled when the sequence is
-    about to be returned to the user.
-    """
-    # The tokens includes the prompt.
-    tokens: List[int]
-    cum_logprob: float = 0.0
-    text: Optional[str] = None
-
-
-@dataclass
-class BeamSearchOutput:
-    """The output of beam search.
-    It contains the list of the best beam search sequences.
-    The length of the list is equal to the beam width.
-    """
-    sequences: List[BeamSearchSequence]
-
-
-class BeamSearchInstance:
-
-    def __init__(self, prompt_tokens: List[int]):
-        self.beams: List[BeamSearchSequence] = [
-            BeamSearchSequence(tokens=prompt_tokens)
-        ]
-        self.completed: List[BeamSearchSequence] = []
-
-
 class LLM:
     """An LLM for generating texts from given prompts and sampling parameters.
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index c4652be6fe8..1e85167ea76 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -10,6 +10,7 @@
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          apply_hf_chat_template,
@@ -236,15 +237,16 @@ async def create_chat_completion(
                     log_tracing_disabled_warning()
 
             if isinstance(sampling_params, BeamSearchParams):
-                if not isinstance(self.engine_client, AsyncLLMEngine):
-                    raise ValueError(
-                        "Beam search in the API server is only supported with"
-                        " AsyncLLMEngine. please add "
-                        "`--disable-frontend-multiprocessing` to "
-                        "use beam search.")
+                assert isinstance(self.engine_client,
+                                  (AsyncLLMEngine,
+                                   MQLLMEngineClient)), \
+                    "Beam search is only supported with" \
+                    "AsyncLLMEngine and MQLLMEngineClient."
                 result_generator = self.engine_client.beam_search(
-                    engine_inputs['prompt_token_ids'], request_id,
-                    sampling_params)
+                    engine_inputs['prompt_token_ids'],
+                    request_id,
+                    sampling_params,
+                )
             else:
                 result_generator = self.engine_client.generate(
                     engine_inputs,
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index bf9e9850797..077312dd141 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -9,6 +9,7 @@
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
@@ -150,15 +151,16 @@ async def create_completion(
                     log_tracing_disabled_warning()
 
             if isinstance(sampling_params, BeamSearchParams):
-                if not isinstance(self.engine_client, AsyncLLMEngine):
-                    raise ValueError(
-                        "Beam search in the API server is only supported"
-                        " with AsyncLLMEngine. please add "
-                        "`--disable-frontend-multiprocessing` to "
-                        "use beam search.")
+                assert isinstance(self.engine_client,
+                                  (AsyncLLMEngine,
+                                   MQLLMEngineClient)), \
+                    "Beam search is only supported with" \
+                    "AsyncLLMEngine and MQLLMEngineClient."
                 generator = self.engine_client.beam_search(
-                    prompt_inputs["prompt_token_ids"], request_id_item,
-                    sampling_params)
+                    prompt_inputs["prompt_token_ids"],
+                    request_id_item,
+                    sampling_params,
+                )
             else:
                 generator = self.engine_client.generate(
                     {
diff --git a/vllm/utils.py b/vllm/utils.py
index 1b7638c4a12..e44365fa249 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1361,22 +1361,3 @@ def dec(self, num=1):
     @property
     def value(self):
         return self._value
-
-
-def get_beam_search_score(
-    tokens: List[int],
-    cumulative_logprob: float,
-    eos_token_id: int,
-    length_penalty: float = 1.0,
-) -> float:
-    """Calculate the beam search score with length penalty.
-
-    Adapted from
-
-    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
-    """
-    seq_len = len(tokens)
-    if tokens[-1] == eos_token_id:
-        seq_len -= 1
-
-    return cumulative_logprob / (seq_len**length_penalty)
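For reference, a minimal sketch (illustration only, not part of the patch) of how the helpers consolidated into vllm/beam_search.py compose; the token IDs, logprob values, and eos_token_id below are made up.

# Illustration: rank two hypothetical beams with the same key the engines
# build via create_sort_beams_key_function, i.e.
# cum_logprob / (seq_len ** length_penalty), excluding a trailing EOS token.
from vllm.beam_search import (BeamSearchSequence,
                              create_sort_beams_key_function)

EOS = 2  # hypothetical eos_token_id

beams = [
    BeamSearchSequence(tokens=[5, 7, 9, EOS], cum_logprob=-1.2),
    BeamSearchSequence(tokens=[5, 7, 9, 11, 13], cum_logprob=-1.5),
]

sort_beams_key = create_sort_beams_key_function(eos_token_id=EOS,
                                                length_penalty=1.0)
best_first = sorted(beams, key=sort_beams_key, reverse=True)
# Scores: -1.2 / 3 = -0.4 and -1.5 / 5 = -0.3, so the longer beam ranks first.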