
Commit 1704dc9

DarkLight1337 and ywang96 authored and committed
[Frontend] Refactor prompt processing (vllm-project#4028)
Co-authored-by: Roger Wang <ywang@roblox.com>
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent 1c69636 commit 1704dc9

24 files changed: +698 −390 lines


benchmarks/benchmark_latency.py

Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@
 
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
-from vllm.inputs import PromptStrictInputs
+from vllm.inputs import PromptInputs
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser
 
@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
     dummy_prompt_token_ids = np.random.randint(10000,
                                                size=(args.batch_size,
                                                      args.input_len))
-    dummy_inputs: List[PromptStrictInputs] = [{
+    dummy_inputs: List[PromptInputs] = [{
         "prompt_token_ids": batch
     } for batch in dummy_prompt_token_ids.tolist()]
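
As context for the rename (not part of the diff), a minimal sketch of the benchmark's pattern under the new name: random token IDs are wrapped as token prompts and passed to LLM.generate. The model name and sizes here are illustrative assumptions; the benchmark takes them from CLI arguments.

    from typing import List

    import numpy as np

    from vllm import LLM, SamplingParams
    from vllm.inputs import PromptInputs

    # Assumed model and batch sizes, for illustration only.
    llm = LLM(model="facebook/opt-125m")
    sampling_params = SamplingParams(max_tokens=16)

    batch_size, input_len = 4, 32
    dummy_prompt_token_ids = np.random.randint(10000, size=(batch_size, input_len))

    # Mirror the benchmark: each batch of token IDs becomes a token prompt.
    dummy_inputs: List[PromptInputs] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    outputs = llm.generate(dummy_inputs, sampling_params=sampling_params)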

docs/source/dev/multimodal/multimodal_index.rst

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ Multi-Modality
 vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
 
 Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
-via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`.
+via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
 
 Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
 by following :ref:`this guide <adding_multimodal_plugin>`.

docs/source/dev/offline_inference/llm_inputs.rst

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 LLM Inputs
 ==========
 
-.. autodata:: vllm.inputs.PromptStrictInputs
+.. autodata:: vllm.inputs.PromptInputs
 
 .. autoclass:: vllm.inputs.TextPrompt
     :show-inheritance:

docs/source/models/vlm.rst

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
 internally for each model.
 
 
-To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
+To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
 
 * ``prompt``: The prompt should follow the format that is documented on HuggingFace.
 * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
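
To make the two fields above concrete, a hedged sketch of a multi-modal prompt (not taken from the diff): the model name, prompt string, and image path are assumptions, and the exact prompt format is model-specific per its HuggingFace card.

    from PIL import Image

    from vllm import LLM

    # Assumed LLaVA-style model and local image; depending on the vLLM version,
    # the additional VLM engine arguments mentioned above may also be required.
    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    image = Image.open("example.jpg")

    outputs = llm.generate({
        # "prompt" follows the format documented on HuggingFace for this model.
        "prompt": "USER: <image>\nWhat is shown in this image?\nASSISTANT:",
        # "multi_modal_data" follows vllm.multimodal.MultiModalDataDict.
        "multi_modal_data": {"image": image},
    })
    print(outputs[0].outputs[0].text)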

tests/engine/output_processor/test_stop_checker.py

Lines changed: 2 additions & 2 deletions

@@ -35,8 +35,8 @@ def sequence_with_eos(text: str, eos_token: str,
 @pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
     ("This text ends with EOS token", "</s>", 2),
 ])
-@pytest.mark.parametrize("ignore_eos", [True, False, None])
-@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None])
+@pytest.mark.parametrize("ignore_eos", [True, False])
+@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
 @pytest.mark.skip_global_cleanup
 def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
                            ignore_eos: bool, include_stop_str_in_output: bool):

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 4 additions & 1 deletion

@@ -32,7 +32,10 @@ async def _async_serving_chat_init():
         model_config,
         served_model_names=[MODEL_NAME],
         response_role="assistant",
-        chat_template=CHAT_TEMPLATE)
+        chat_template=CHAT_TEMPLATE,
+        lora_modules=None,
+        prompt_adapters=None,
+        request_logger=None)
     return serving_completion
 

vllm/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -5,7 +5,7 @@
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.llm import LLM
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt
+from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
 from vllm.outputs import (CompletionOutput, EmbeddingOutput,
                           EmbeddingRequestOutput, RequestOutput)
@@ -19,7 +19,7 @@
     "__version__",
     "LLM",
     "ModelRegistry",
-    "PromptStrictInputs",
+    "PromptInputs",
     "TextPrompt",
     "TokensPrompt",
     "SamplingParams",

vllm/engine/arg_utils.py

Lines changed: 0 additions & 7 deletions

@@ -827,7 +827,6 @@ class AsyncEngineArgs(EngineArgs):
     """Arguments for asynchronous vLLM engine."""
     engine_use_ray: bool = False
     disable_log_requests: bool = False
-    max_log_len: Optional[int] = None
 
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser,
@@ -841,12 +840,6 @@ def add_cli_args(parser: FlexibleArgumentParser,
         parser.add_argument('--disable-log-requests',
                             action='store_true',
                             help='Disable logging requests.')
-        parser.add_argument('--max-log-len',
-                            type=int,
-                            default=None,
-                            help='Max number of prompt characters or prompt '
-                            'ID numbers being printed in log.'
-                            '\n\nDefault: Unlimited')
         return parser
 

vllm/engine/async_llm_engine.py

Lines changed: 22 additions & 41 deletions

@@ -1,8 +1,8 @@
 import asyncio
 import time
 from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
-                    Set, Tuple, Type, Union)
+from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
+                    Optional, Set, Tuple, Type, Union)
 
 from transformers import PreTrainedTokenizer
 
@@ -151,7 +151,10 @@ def process_exception(self,
         logger.info("Finished request %s.", request_id)
         self.abort_request(request_id)
 
-    def add_request(self, request_id: str,
+    def add_request(self,
+                    request_id: str,
+                    *,
+                    verbose: bool = False,
                     **engine_add_request_kwargs) -> AsyncStream:
         """Add a request to be sent to the engine on the next background
         loop iteration."""
@@ -166,6 +169,9 @@ def add_request(self, request_id: str,
 
         self.new_requests_event.set()
 
+        if verbose:
+            logger.info("Added request %s.", request_id)
+
         return stream
 
     def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@@ -337,14 +343,14 @@ async def process_model_params_async(
         return params
 
     async def add_request_async(
-        self,
-        request_id: str,
-        inputs: PromptInputs,
-        params: Union[SamplingParams, PoolingParams],
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+            self,
+            request_id: str,
+            inputs: PromptInputs,
+            params: Union[SamplingParams, PoolingParams],
+            arrival_time: Optional[float] = None,
+            lora_request: Optional[LoRARequest] = None,
+            trace_headers: Optional[Mapping[str, str]] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -393,8 +399,6 @@ class AsyncLLMEngine:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
-        max_log_len: Maximum number of prompt characters or prompt ID numbers
-            being printed in log.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
         *args: Arguments for :class:`LLMEngine`.
@@ -408,13 +412,11 @@ def __init__(self,
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
-                 max_log_len: Optional[int] = None,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)
 
         self.background_loop: Optional[asyncio.Future] = None
@@ -508,7 +510,6 @@ def from_engine_args(
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
             log_stats=not engine_args.disable_log_stats,
-            max_log_len=engine_args.max_log_len,
             start_engine_loop=start_engine_loop,
             usage_context=usage_context,
             stat_loggers=stat_loggers,
@@ -707,30 +708,9 @@ async def add_request(
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncStream:
-        if self.log_requests:
-            if isinstance(inputs, str):
-                shortened_prompt = inputs
-                shortened_token_ids = None
-            else:
-                shortened_prompt = inputs.get("prompt")
-                shortened_token_ids = inputs.get("prompt_token_ids")
-
-            max_log_len = self.max_log_len
-            if max_log_len is not None:
-                if shortened_prompt is not None:
-                    shortened_prompt = shortened_prompt[:max_log_len]
-                if shortened_token_ids is not None:
-                    shortened_token_ids = shortened_token_ids[:max_log_len]
-
-            logger.info(
-                "Received request %s: prompt: %r, "
-                "params: %s, prompt_token_ids: %s, "
-                "lora_request: %s.", request_id, shortened_prompt, params,
-                shortened_token_ids, lora_request)
-
         if not self.is_running:
             if self.start_engine_loop:
                 self.start_background_loop()
@@ -746,6 +726,7 @@ async def add_request(
 
         stream = self._request_tracker.add_request(
             request_id,
+            verbose=self.log_requests,
             inputs=inputs,
             params=params,
             arrival_time=arrival_time,
@@ -761,7 +742,7 @@ async def generate(
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
@@ -844,7 +825,7 @@ async def encode(
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
     ) -> AsyncIterator[EmbeddingRequestOutput]:
         """Generate outputs for a request from an embedding model.
@@ -922,7 +903,7 @@ async def _process_request(
         params: Union[SamplingParams, PoolingParams],
         *,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Common logic to process requests with SamplingParams or

vllm/engine/llm_engine.py

Lines changed: 5 additions & 4 deletions

@@ -1,6 +1,7 @@
 import time
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
+from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
+                    Mapping, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, TypeVar, Union
 
@@ -525,7 +526,7 @@ def _add_processed_request(
         arrival_time: float,
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
     ) -> None:
         # Create the sequences.
         block_size = self.cache_config.block_size
@@ -643,7 +644,7 @@ def add_request(
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         """Add a request to the engine's request pool.
@@ -721,7 +722,7 @@ def _create_sequence_group_with_sampling(
         sampling_params: SamplingParams,
         arrival_time: float,
         lora_request: Optional[LoRARequest],
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""
