
Commit b462585

DarkLight1337 authored and Alvant committed
[Core] Factor out input preprocessing to a separate class (vllm-project#7329)
1 parent ba7ee4f commit b462585
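
This excerpt only shows the call-site side of the refactor: add_request_async in vllm/engine/async_llm_engine.py now awaits self.input_preprocessor.preprocess_async(...) instead of the engine's own process_model_inputs_async, and the duplicated async prompt-handling helpers are deleted from the engine. The new class itself is not part of this excerpt, so the following is only a rough sketch of the shape it could take; the class name InputPreprocessor, its constructor, and every helper other than preprocess_async are assumptions for illustration, not the committed API.

from typing import Any, Dict, List, Optional, Union


class InputPreprocessor:
    """Sketch of prompt preprocessing factored out of the engines."""

    def __init__(self, tokenizer_group: Any) -> None:
        # Anything exposing encode_async(), e.g. the engine's tokenizer group.
        self._tokenizer_group = tokenizer_group

    async def _tokenize_prompt_async(self,
                                     prompt: str,
                                     request_id: str,
                                     lora_request: Optional[Any] = None) -> List[int]:
        # Delegates to the tokenizer group's async encode, as the removed
        # engine helper did.
        return await self._tokenizer_group.encode_async(
            request_id=request_id, prompt=prompt, lora_request=lora_request)

    async def preprocess_async(self,
                               inputs: Union[str, Dict[str, Any]],
                               request_id: str,
                               lora_request: Optional[Any] = None) -> Dict[str, Any]:
        # Accept raw text or a dict that already carries token IDs, mirroring
        # the branching of the removed _extract_prompt_components_async.
        if isinstance(inputs, str):
            prompt: Optional[str] = inputs
            prompt_token_ids = await self._tokenize_prompt_async(
                inputs, request_id=request_id, lora_request=lora_request)
        elif "prompt_token_ids" in inputs:
            prompt = None
            prompt_token_ids = inputs["prompt_token_ids"]
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt, request_id=request_id, lora_request=lora_request)

        multi_modal_data = (inputs.get("multi_modal_data")
                            if isinstance(inputs, dict) else None)
        return {
            "prompt": prompt,
            "prompt_token_ids": prompt_token_ids,
            "multi_modal_data": multi_modal_data,
        }

Centralizing this logic in one object means LLMEngine and AsyncLLMEngine no longer carry parallel sync/async copies of the same prompt-parsing helpers, which is exactly what the large deletion in vllm/engine/async_llm_engine.py below removes.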

5 files changed: +589 −537 lines changed


tests/engine/test_skip_tokenizer_init.py

Lines changed: 3 additions & 2 deletions
@@ -11,9 +11,10 @@ def test_skip_tokenizer_initialization(model: str):
     # token ids.
     llm = LLM(model=model, skip_tokenizer_init=True)
     sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
-    with pytest.raises(ValueError) as err:
+
+    with pytest.raises(ValueError, match="cannot pass text prompts when"):
         llm.generate("abc", sampling_params)
-    assert "prompts must be None if" in str(err.value)
+
     outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
                            sampling_params=sampling_params)
     assert len(outputs) > 0
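
The updated test asserts the new error message directly via pytest.raises(..., match=...) instead of inspecting the exception text afterwards. For reference, a standalone sketch of the behaviour the test pins down; the model name is a placeholder (the real test is parametrized over model) and is an assumption here:

import pytest

from vllm import LLM, SamplingParams

# Placeholder model; the test receives `model` from its parametrization.
llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)

# Text prompts need a tokenizer, so this now fails with the new message.
with pytest.raises(ValueError, match="cannot pass text prompts when"):
    llm.generate("abc", sampling_params)

# Pre-tokenized prompts bypass the tokenizer and still work.
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
                       sampling_params=sampling_params)
assert len(outputs) > 0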

vllm/engine/async_llm_engine.py

Lines changed: 3 additions & 141 deletions
@@ -4,22 +4,17 @@
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)
 
-from typing_extensions import assert_never
-
 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
-from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
-                                    PromptComponents, SchedulerOutputState)
+from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
-                         SingletonPromptInputs)
-from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
+from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -403,139 +398,6 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
         """Stop the remote worker execution loop."""
         await self.model_executor.stop_remote_worker_execution_loop_async()
 
-    async def _tokenize_prompt_async(
-        self,
-        prompt: str,
-        request_id: str,
-        lora_request: Optional[LoRARequest],
-    ) -> List[int]:
-        """Async version of :meth:`_tokenize_prompt`."""
-        tokenizer = self.get_tokenizer_group(
-            missing_msg="prompts must be None if skip_tokenizer_init is True")
-
-        return await tokenizer.encode_async(request_id=request_id,
-                                            prompt=prompt,
-                                            lora_request=lora_request)
-
-    async def _extract_prompt_components_async(
-        self,
-        inputs: SingletonPromptInputs,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> PromptComponents:
-        """Async version of :meth:`_extract_prompt_components`."""
-        if isinstance(inputs, str):
-            prompt = inputs
-            prompt_token_ids = await self._tokenize_prompt_async(
-                prompt,
-                request_id=request_id,
-                lora_request=lora_request,
-            )
-            multi_modal_data = None
-        elif isinstance(inputs, dict):
-            if "prompt_token_ids" in inputs:
-                prompt = None
-                prompt_token_ids = inputs["prompt_token_ids"]
-            else:
-                # NOTE: This extra assignment is required to pass mypy
-                prompt = parsed_prompt = inputs["prompt"]
-                prompt_token_ids = await self._tokenize_prompt_async(
-                    parsed_prompt,
-                    request_id=request_id,
-                    lora_request=lora_request,
-                )
-
-            multi_modal_data = inputs.get("multi_modal_data")
-        else:
-            assert_never(inputs)
-
-        return prompt, prompt_token_ids, multi_modal_data
-
-    async def _process_encoder_decoder_prompt_async(
-        self,
-        inputs: PromptInputs,
-        request_id: str,
-    ) -> EncoderDecoderLLMInputs:
-        """Async version of :meth:`_process_encoder_decoder_prompt`."""
-        encoder_comps: PromptComponents
-        decoder_comps: DecoderPromptComponents
-
-        if is_explicit_encoder_decoder_prompt(inputs):
-            encoder_task = self._extract_prompt_components_async(
-                inputs["encoder_prompt"],
-                request_id=request_id,
-            )
-
-            if (decoder_input := inputs["decoder_prompt"]) is None:
-                encoder_comps = await encoder_task
-                decoder_comps = None, None, None
-            else:
-                decoder_task = self._extract_prompt_components_async(
-                    decoder_input,
-                    request_id=request_id,
-                )
-
-                encoder_comps, decoder_comps = await asyncio.gather(
-                    encoder_task, decoder_task)
-        else:
-            encoder_comps = await self._extract_prompt_components_async(
-                inputs,
-                request_id=request_id,
-            )
-
-            decoder_comps = None, None, None
-
-        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
-
-    async def _process_decoder_only_prompt_async(
-        self,
-        inputs: SingletonPromptInputs,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> LLMInputs:
-        """Async version of :meth:`_process_decoder_only_prompt`."""
-        prompt_comps = await self._extract_prompt_components_async(
-            inputs,
-            request_id=request_id,
-            lora_request=lora_request,
-        )
-
-        return self._build_decoder_only_llm_inputs(
-            prompt_comps,
-            prompt_adapter_request=prompt_adapter_request,
-        )
-
-    async def process_model_inputs_async(
-        self,
-        inputs: PromptInputs,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
-        """Async version of :meth:`process_model_inputs`."""
-        if self.is_encoder_decoder_model():
-            # Encoder-decoder model requires special mapping of
-            # input prompts to encoder & decoder
-            model_inputs = await self._process_encoder_decoder_prompt_async(
-                inputs,
-                request_id=request_id,
-            )
-        else:
-            if is_explicit_encoder_decoder_prompt(inputs):
-                raise ValueError("Cannot pass encoder-decoder prompt "
-                                 "to decoder-only models")
-
-            # Decoder-only operation
-            model_inputs = await self._process_decoder_only_prompt_async(
-                inputs,
-                request_id=request_id,
-                lora_request=lora_request,
-                prompt_adapter_request=prompt_adapter_request,
-            )
-
-        return self.input_processor(model_inputs)
-
     async def process_model_params_async(
         self,
         request_id: str,
@@ -591,7 +453,7 @@ async def add_request_async(
         if arrival_time is None:
            arrival_time = time.time()
 
-        processed_inputs = await self.process_model_inputs_async(
+        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
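
Only the changed lines of add_request_async appear above; the rest of the method is cut off in this excerpt. A hedged sketch of how the surrounding call site plausibly looks once the engine owns an input preprocessor; everything beyond the preprocess_async call shown in the diff (the method signature, what is done with preprocessed_inputs afterwards) is assumed for illustration:

import time
from typing import Any, Optional


class AsyncEngineCallSiteSketch:
    """Illustration of the new delegation pattern, not the actual engine."""

    def __init__(self, input_preprocessor: Any) -> None:
        self.input_preprocessor = input_preprocessor

    async def add_request_async(self,
                                request_id: str,
                                inputs: Any,
                                arrival_time: Optional[float] = None,
                                lora_request: Optional[Any] = None) -> Any:
        if arrival_time is None:
            arrival_time = time.time()

        # All prompt handling (tokenization, encoder/decoder splitting,
        # multi-modal data) is now delegated to the preprocessor object.
        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )

        # The excerpt ends here; downstream, the engine would hand the
        # preprocessed inputs to scheduling as before.
        return preprocessed_inputs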
