
Commit 739b61a

[Frontend] Refactor prompt processing (#4028)
Co-authored-by: Roger Wang <ywang@roblox.com>
1 parent 89c1c6a commit 739b61a

24 files changed, +698 -390 lines changed

24 files changed

+698
-390
lines changed

benchmarks/benchmark_latency.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
-from vllm.inputs import PromptStrictInputs
+from vllm.inputs import PromptInputs
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser
 
@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
     dummy_prompt_token_ids = np.random.randint(10000,
                                                size=(args.batch_size,
                                                      args.input_len))
-    dummy_inputs: List[PromptStrictInputs] = [{
+    dummy_inputs: List[PromptInputs] = [{
         "prompt_token_ids": batch
     } for batch in dummy_prompt_token_ids.tolist()]
 
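The renamed type accepts the same dict forms the benchmark builds above. A minimal usage sketch; the model name and the exact generate() call shape are illustrative assumptions, not part of this diff:

    from typing import List

    from vllm import LLM, SamplingParams
    from vllm.inputs import PromptInputs

    # TokensPrompt-style dicts, mirroring the benchmark's dummy_inputs above.
    inputs: List[PromptInputs] = [
        {"prompt_token_ids": [101, 2023, 2003, 1037, 3231]},
        {"prompt_token_ids": [101, 2178, 14758, 3231]},
    ]

    llm = LLM(model="facebook/opt-125m")  # illustrative model choice
    outputs = llm.generate(inputs, SamplingParams(max_tokens=16))
    for output in outputs:
        print(output.outputs[0].text)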

docs/source/dev/multimodal/multimodal_index.rst

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ Multi-Modality
 vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
 
 Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
-via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`.
+via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
 
 Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
 by following :ref:`this guide <adding_multimodal_plugin>`.
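As the changed paragraph states, multi-modal data rides alongside either a text or a token prompt. A sketch of the two input shapes only, assuming the "image" key of the built-in image support mentioned above (the image file and token IDs are placeholders):

    from PIL import Image

    from vllm.inputs import TextPrompt, TokensPrompt

    image = Image.open("example.jpg")  # placeholder path

    # Text prompt carrying multi-modal data.
    text_prompt: TextPrompt = {
        "prompt": "USER: <image>\nDescribe the image.\nASSISTANT:",
        "multi_modal_data": {"image": image},
    }

    # Token prompt carrying the same multi-modal data.
    tokens_prompt: TokensPrompt = {
        "prompt_token_ids": [1, 3148, 1001, 29901],
        "multi_modal_data": {"image": image},
    }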

docs/source/dev/offline_inference/llm_inputs.rst

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 LLM Inputs
 ==========
 
-.. autodata:: vllm.inputs.PromptStrictInputs
+.. autodata:: vllm.inputs.PromptInputs
 
 .. autoclass:: vllm.inputs.TextPrompt
     :show-inheritance:

docs/source/models/vlm.rst

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
 internally for each model.
 
 
-To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
+To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
 
 * ``prompt``: The prompt should follow the format that is documented on HuggingFace.
 * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
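An end-to-end sketch of the two fields listed above; the model name, chat format, and image file are illustrative assumptions, and the prompt format must follow the model's HuggingFace documentation as the doc notes:

    from PIL import Image

    from vllm import LLM, SamplingParams

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # illustrative VLM choice
    image = Image.open("cherry_blossom.jpg")     # placeholder local image

    outputs = llm.generate(
        {
            # "prompt": format documented on HuggingFace for the chosen model.
            "prompt": "USER: <image>\nWhat is the content of this image?\nASSISTANT:",
            # "multi_modal_data": schema defined by vllm.multimodal.MultiModalDataDict.
            "multi_modal_data": {"image": image},
        },
        SamplingParams(max_tokens=64),
    )
    print(outputs[0].outputs[0].text)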

tests/engine/output_processor/test_stop_checker.py

Lines changed: 2 additions & 2 deletions
@@ -35,8 +35,8 @@ def sequence_with_eos(text: str, eos_token: str,
 @pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
     ("This text ends with EOS token", "</s>", 2),
 ])
-@pytest.mark.parametrize("ignore_eos", [True, False, None])
-@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None])
+@pytest.mark.parametrize("ignore_eos", [True, False])
+@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
 @pytest.mark.skip_global_cleanup
 def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
                            ignore_eos: bool, include_stop_str_in_output: bool):

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 4 additions & 1 deletion
@@ -32,7 +32,10 @@ async def _async_serving_chat_init():
         model_config,
         served_model_names=[MODEL_NAME],
         response_role="assistant",
-        chat_template=CHAT_TEMPLATE)
+        chat_template=CHAT_TEMPLATE,
+        lora_modules=None,
+        prompt_adapters=None,
+        request_logger=None)
     return serving_completion
 
vllm/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.llm import LLM
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt
+from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
 from vllm.outputs import (CompletionOutput, EmbeddingOutput,
                           EmbeddingRequestOutput, RequestOutput)

@@ -19,7 +19,7 @@
     "__version__",
     "LLM",
     "ModelRegistry",
-    "PromptStrictInputs",
+    "PromptInputs",
     "TextPrompt",
     "TokensPrompt",
     "SamplingParams",

vllm/engine/arg_utils.py

Lines changed: 0 additions & 7 deletions
@@ -827,7 +827,6 @@ class AsyncEngineArgs(EngineArgs):
     """Arguments for asynchronous vLLM engine."""
     engine_use_ray: bool = False
     disable_log_requests: bool = False
-    max_log_len: Optional[int] = None
 
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser,

@@ -841,12 +840,6 @@ def add_cli_args(parser: FlexibleArgumentParser,
         parser.add_argument('--disable-log-requests',
                             action='store_true',
                             help='Disable logging requests.')
-        parser.add_argument('--max-log-len',
-                            type=int,
-                            default=None,
-                            help='Max number of prompt characters or prompt '
-                            'ID numbers being printed in log.'
-                            '\n\nDefault: Unlimited')
         return parser
 

vllm/engine/async_llm_engine.py

Lines changed: 22 additions & 41 deletions
@@ -1,8 +1,8 @@
 import asyncio
 import time
 from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
-                    Set, Tuple, Type, Union)
+from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
+                    Optional, Set, Tuple, Type, Union)
 
 from transformers import PreTrainedTokenizer
 

@@ -151,7 +151,10 @@ def process_exception(self,
         logger.info("Finished request %s.", request_id)
         self.abort_request(request_id)
 
-    def add_request(self, request_id: str,
+    def add_request(self,
+                    request_id: str,
+                    *,
+                    verbose: bool = False,
                     **engine_add_request_kwargs) -> AsyncStream:
         """Add a request to be sent to the engine on the next background
         loop iteration."""

@@ -166,6 +169,9 @@ def add_request(self, request_id: str,
 
         self.new_requests_event.set()
 
+        if verbose:
+            logger.info("Added request %s.", request_id)
+
         return stream
 
     def abort_request(self, request_id: str, *, verbose: bool = False) -> None:

@@ -299,14 +305,14 @@ async def process_model_inputs_async(
         return self.input_processor(llm_inputs)
 
     async def add_request_async(
-            self,
-            request_id: str,
-            inputs: PromptInputs,
-            params: Union[SamplingParams, PoolingParams],
-            arrival_time: Optional[float] = None,
-            lora_request: Optional[LoRARequest] = None,
-            trace_headers: Optional[Dict[str, str]] = None,
-            prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        self,
+        request_id: str,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "

@@ -353,8 +359,6 @@ class AsyncLLMEngine:
         async frontend will be executed in a separate process as the
         model workers.
         log_requests: Whether to log the requests.
-        max_log_len: Maximum number of prompt characters or prompt ID numbers
-            being printed in log.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
         *args: Arguments for :class:`LLMEngine`.

@@ -368,13 +372,11 @@ def __init__(self,
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
-                 max_log_len: Optional[int] = None,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)
 
         self.background_loop: Optional[asyncio.Future] = None

@@ -468,7 +470,6 @@ def from_engine_args(
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
             log_stats=not engine_args.disable_log_stats,
-            max_log_len=engine_args.max_log_len,
             start_engine_loop=start_engine_loop,
             usage_context=usage_context,
             stat_loggers=stat_loggers,

@@ -667,30 +668,9 @@ async def add_request(
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncStream:
-        if self.log_requests:
-            if isinstance(inputs, str):
-                shortened_prompt = inputs
-                shortened_token_ids = None
-            else:
-                shortened_prompt = inputs.get("prompt")
-                shortened_token_ids = inputs.get("prompt_token_ids")
-
-            max_log_len = self.max_log_len
-            if max_log_len is not None:
-                if shortened_prompt is not None:
-                    shortened_prompt = shortened_prompt[:max_log_len]
-                if shortened_token_ids is not None:
-                    shortened_token_ids = shortened_token_ids[:max_log_len]
-
-            logger.info(
-                "Received request %s: prompt: %r, "
-                "params: %s, prompt_token_ids: %s, "
-                "lora_request: %s.", request_id, shortened_prompt, params,
-                shortened_token_ids, lora_request)
-
         if not self.is_running:
             if self.start_engine_loop:
                 self.start_background_loop()

@@ -706,6 +686,7 @@ async def add_request(
 
         stream = self._request_tracker.add_request(
             request_id,
+            verbose=self.log_requests,
             inputs=inputs,
             params=params,
             arrival_time=arrival_time,

@@ -721,7 +702,7 @@ async def generate(
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.

@@ -804,7 +785,7 @@ async def encode(
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
     ) -> AsyncIterator[EmbeddingRequestOutput]:
         """Generate outputs for a request from an embedding model.
 

@@ -882,7 +863,7 @@ async def _process_request(
         params: Union[SamplingParams, PoolingParams],
         *,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Common logic to process requests with SamplingParams or

vllm/engine/llm_engine.py

Lines changed: 5 additions & 4 deletions
@@ -1,6 +1,7 @@
 import time
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
+from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
+                    Mapping, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, TypeVar, Union
 

@@ -522,7 +523,7 @@ def _add_processed_request(
         arrival_time: float,
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
     ) -> None:
         # Create the sequences.
         block_size = self.cache_config.block_size

@@ -603,7 +604,7 @@ def add_request(
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         """Add a request to the engine's request pool.

@@ -677,7 +678,7 @@ def _create_sequence_group_with_sampling(
         sampling_params: SamplingParams,
         arrival_time: float,
         lora_request: Optional[LoRARequest],
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""
