[Misc] Clean up input processing #17582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 7 commits, May 2, 2025
Changes from 2 commits
vllm/engine/async_llm_engine.py (4 changes: 0 additions & 4 deletions)

@@ -497,10 +497,6 @@ async def add_request_async(
prompt["prompt_token_ids"] = [0
] * prompt["prompt_embeds"].shape[-2]

if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)

processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,
vllm/engine/llm_engine.py (32 changes: 5 additions & 27 deletions)

@@ -30,7 +30,7 @@
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
@@ -759,11 +759,6 @@ def add_request(
seq_len = prompt["prompt_embeds"].shape[0]
prompt["prompt_token_ids"] = [0] * seq_len

if self.tokenizer is not None:
self._validate_token_prompt(
prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request))

processed_inputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
@@ -782,27 +777,6 @@ def add_request(
priority=priority,
)

def _validate_token_prompt(self, prompt: PromptType,
tokenizer: AnyTokenizer):
# Guard against out-of-vocab tokens.
# For some tokenizers, tokenizer.decode will happily return empty text
# for token ids that are out of vocab, and we don't detect token ids
# that are greater than the max token id before running the model.
# However, these token ids will later crash a cuda kernel at runtime
# with an index out of bounds error. This will crash the entire engine.
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if is_token_prompt(prompt):
prompt_ids = prompt["prompt_token_ids"]
if len(prompt_ids) == 0:
# Empty prompt check is handled later
return
max_input_id = max(prompt_ids)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))

Comment on lines -785 to -805
DarkLight1337 (Member, Author) commented on May 2, 2025:

#11980 lets us move this to after the processor is applied, making V0 validation (_validate_model_inputs) the same as V1.
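
For context, a minimal standalone sketch of the relocated guard; the helper name `validate_token_ids` and the sample vocabulary size are invented for illustration and are not part of vLLM's API:

```python
# Standalone sketch of the out-of-vocab check that now runs in
# _validate_model_input, i.e. after input processing.
from typing import Sequence


def validate_token_ids(prompt_ids: Sequence[int], max_token_id: int) -> None:
    """Reject token ids the tokenizer's vocabulary cannot represent.

    Out-of-vocab ids are not caught by tokenizer.decode and would otherwise
    surface only as an index-out-of-bounds error inside a CUDA kernel,
    crashing the whole engine at runtime.
    """
    # default=0 keeps the check a no-op for empty prompts, which are
    # rejected separately by the emptiness check.
    max_input_id = max(prompt_ids, default=0)
    if max_input_id > max_token_id:
        raise ValueError(f"Token id {max_input_id} is out of vocabulary")


validate_token_ids([15496, 995], max_token_id=50256)   # OK
# validate_token_ids([60000], max_token_id=50256)      # raises ValueError
```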

def _create_sequence_group_with_sampling(
self,
request_id: str,
@@ -2049,6 +2023,10 @@ def _validate_model_input(
else:
raise ValueError(f"The {prompt_type} prompt cannot be empty")

max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id: # type: ignore
raise ValueError(f"Token id {max_input_id} is out of vocabulary")

max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model:
vllm/engine/protocol.py (3 changes: 3 additions & 0 deletions)

@@ -83,6 +83,9 @@ async def beam_search(
else:
processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)

if processed_inputs["type"] == "embeds":
raise NotImplementedError

prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt")
multi_modal_data = processed_inputs.get("multi_modal_data")
vllm/entrypoints/llm.py (6 changes: 4 additions & 2 deletions)

@@ -27,7 +27,7 @@
_validate_score_input_lens)
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
@@ -567,10 +567,12 @@ def create_tokens_prompt_from_beam(
mm_kwargs["mm_processor_kwargs"] = prompt[
"mm_processor_kwargs"]

if is_token_prompt(prompt):
if "prompt_token_ids" in prompt:
prompt = cast(TokensPrompt, prompt) # Needed for mypy
prompt_tokens = prompt["prompt_token_ids"]
else:
prompt_tokens = tokenizer.encode(prompt["prompt"])

instances.append(
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

vllm/inputs/data.py (23 changes: 18 additions & 5 deletions)

@@ -70,6 +70,11 @@ class EmbedsPrompt(TypedDict):
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""

cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""

Comment on lines +73 to +77
DarkLight1337 (Member, Author) commented on May 2, 2025:

Although we don't support embedding inputs in V1 yet, I added this so we don't forget to assign cache_salt during processing when we do implement it.
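
A short usage sketch of the updated helper, assuming the import path from the file shown in this diff; the tensor shape and salt value are placeholders:

```python
# Usage sketch for embeds_inputs with the new cache_salt argument.
# The tensor shape (5 tokens x 4096 hidden size) is a placeholder.
import torch

from vllm.inputs.data import embeds_inputs

prompt_embeds = torch.randn(5, 4096)

# Without a salt, the "cache_salt" key is omitted entirely.
inputs = embeds_inputs(prompt_embeds=prompt_embeds)
assert "cache_salt" not in inputs

# With a salt, it is stored so prefix caching can account for it once
# embedding inputs are supported in V1.
salted = embeds_inputs(prompt_embeds=prompt_embeds, cache_salt="session-123")
assert salted["cache_salt"] == "session-123"
```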


SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
@@ -195,13 +200,21 @@ class EmbedsInputs(TypedDict):
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""

cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""


def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs:
def embeds_inputs(
prompt_embeds: torch.Tensor,
cache_salt: Optional[str] = None,
) -> EmbedsInputs:
"""Construct :class:`EmbedsInputs` from optional values."""
inputs = EmbedsInputs(
type="embeds",
prompt_embeds=prompt_embeds,
)
inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)

if cache_salt is not None:
inputs["cache_salt"] = cache_salt

return inputs

vllm/inputs/parse.py (27 changes: 8 additions & 19 deletions)

@@ -6,9 +6,9 @@

from vllm.utils import is_list_of

from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokensPrompt)
from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
TokensPrompt)


class ParsedText(TypedDict):
@@ -90,6 +90,10 @@ class ParsedEmbedsPrompt(TypedDict):
content: EmbedsPrompt


ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
ParsedTokensPrompt, ParsedEmbedsPrompt]


@overload
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
...
@@ -110,10 +114,7 @@ def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
...


def parse_singleton_prompt(
prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
ParsedEmbedsPrompt]:
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict):
@@ -131,23 +132,11 @@ def parse_singleton_prompt(
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")


def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt


def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]:
return isinstance(prompt, dict) and "prompt_embeds" in prompt


def is_explicit_encoder_decoder_prompt(
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt


def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]:
return isinstance(inputs, dict) and inputs["type"] == "embeds"


def split_enc_dec_inputs(
inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]: