
Feature/vllm/input embedding completion api #17590


Draft · wants to merge 80 commits into base: main

Changes from all commits (80 commits)
cef6894  (vllm) add input embedding (Jan 2, 2025)
c51d8fb  improve embedding input (Bryce1010, Jan 6, 2025)
9564b40  (vllm) fix import error (Bryce1010, Mar 6, 2025)
c60298a  (vllm) fix pre commit error (Bryce1010, Mar 6, 2025)
0c24a82  apply ruff and isort fixes (qthequartermasterman, Mar 25, 2025)
403a165  apply ruff and isort fixes (qthequartermasterman, Mar 25, 2025)
b1ac072  styling (qthequartermasterman, Mar 25, 2025)
0390c33  fix missing imports from rebase (qthequartermasterman, Mar 25, 2025)
0ca4dae  typing fixes (qthequartermasterman, Mar 25, 2025)
35320fe  type fix (qthequartermasterman, Mar 25, 2025)
0a77630  type fix (qthequartermasterman, Mar 25, 2025)
11b6c02  remove unnecessary changes (qthequartermasterman, Mar 25, 2025)
cb92a3d  remove unnecessary changes (qthequartermasterman, Mar 25, 2025)
375bd5b  re-add deleted whitespace (qthequartermasterman, Mar 25, 2025)
c9d8024  Include unit tests from #6869. (qthequartermasterman, Mar 25, 2025)
a64e627  remove unrelated qwen2 changes (qthequartermasterman, Mar 26, 2025)
6ab349e  guard clause around fully consumed prompt embeds to avoid returning e… (qthequartermasterman, Mar 27, 2025)
26c8784  use v0 for prompt embeds model runner tests (qthequartermasterman, Mar 27, 2025)
b71a13c  fix batching of input embeddings (qthequartermasterman, Apr 2, 2025)
4aa9ade  style formatting (qthequartermasterman, Apr 2, 2025)
e2c4c26  remove incorrect overload (qthequartermasterman, Apr 3, 2025)
26d108a  remove incorrect overload (qthequartermasterman, Apr 3, 2025)
af20435  Update representations (qthequartermasterman, Apr 4, 2025)
25aaf3f  remove unrelated changes to docs (qthequartermasterman, Apr 4, 2025)
bc05860  remove unrelated typing change (qthequartermasterman, Apr 4, 2025)
b55800d  fix missing syntax (qthequartermasterman, Apr 4, 2025)
be42a17  do not schedule prompt embeds and non-prompt embeds in the same batch (qthequartermasterman, Apr 4, 2025)
c8fcfe4  fix style linelength (qthequartermasterman, Apr 4, 2025)
b21688f  Merge branch 'main' into feature/vllm/add-input-embedding (qthequartermasterman, Apr 7, 2025)
1e359ae  propogate embeddings for sampled output tokens for decoding (qthequartermasterman, Apr 11, 2025)
59fbe70  fix type check (qthequartermasterman, Apr 11, 2025)
c152a3a  do not schedule decode sequence groups with batches containing both p… (qthequartermasterman, Apr 11, 2025)
42ad800  Merge branch 'main' into feature/vllm/add-input-embedding (qthequartermasterman, Apr 11, 2025)
e7ab2a2  fix type check (qthequartermasterman, Apr 11, 2025)
911adbe  add default value to optional parameter (qthequartermasterman, Apr 11, 2025)
82d923d  remove unused comments (qthequartermasterman, Apr 14, 2025)
c951479  properly pass in placeholder token ids when testing prompt embeds (qthequartermasterman, Apr 15, 2025)
01e1a6e  do not test mixed token_ids/prompt_embeds batches in the model_runner (qthequartermasterman, Apr 15, 2025)
193ad5c  refactor cuda_prepare_decode test (qthequartermasterman, Apr 15, 2025)
74bd9f4  use correct expected input embeds length for prepare_decode_cuda_grap… (qthequartermasterman, Apr 15, 2025)
d949f1b  add scheduler test to ensure prompt embeds and prompt tokens are not … (qthequartermasterman, Apr 15, 2025)
62bbc88  support inputs_embeds in compiled mode (qthequartermasterman, Apr 16, 2025)
1d1ae4b  fix typing in test (qthequartermasterman, Apr 16, 2025)
1914676  use corrector operator precedence for handling empty strings (qthequartermasterman, Apr 16, 2025)
70198f6  only test decoder models with input embeds in v0 backend (qthequartermasterman, Apr 16, 2025)
934ceae  Merge branch 'vllm-project:main' into feature/vllm/add-input-embedding (qthequartermasterman, Apr 16, 2025)
5595b45  adjust type hints for modelinputforgpubuilder.build (qthequartermasterman, Apr 18, 2025)
3343d3e  simplify conditional logic (qthequartermasterman, Apr 18, 2025)
5010ea0  simplify compilation conditional logic (qthequartermasterman, Apr 18, 2025)
2075e53  refactor decoder only language model tests to reduce number of times … (qthequartermasterman, Apr 18, 2025)
9a4fb3c  break up multiple assignments for readability (qthequartermasterman, Apr 18, 2025)
8ad4091  update type hints in scheduler (qthequartermasterman, Apr 18, 2025)
9055daf  clear existing lists instead of instantiating new ones (qthequartermasterman, Apr 18, 2025)
9a57aca  preprocess tensors to handle batched/misshaped prompt embeds to avoid… (qthequartermasterman, Apr 18, 2025)
bbfb0f0  use seperate Embedsprompt class for preprocessing inputs embeddings (qthequartermasterman, Apr 18, 2025)
933e567  fix typing (qthequartermasterman, Apr 18, 2025)
4e0d12f  fix type errors (qthequartermasterman, Apr 19, 2025)
164aeb5  Merge branch 'vllm-project:main' into feature/vllm/add-input-embedding (qthequartermasterman, Apr 19, 2025)
9e6909e  fix mistaken type change (qthequartermasterman, Apr 19, 2025)
90b950a  add missing type hint (qthequartermasterman, Apr 19, 2025)
01d83f4  add spaces for style (qthequartermasterman, Apr 20, 2025)
6985452  seperate EmbedsInputs from TokenInputs and embeds_inputs from token_i… (qthequartermasterman, Apr 20, 2025)
e916551  fix docstrings for EmbedsInputs (qthequartermasterman, Apr 20, 2025)
69f8725  fix typing for token_type_ids (qthequartermasterman, Apr 20, 2025)
9c2c89f  fix typing for embeds_tokens in InputRegistry and InputsAdapter (qthequartermasterman, Apr 20, 2025)
499dc6a  remove prompts and prompt_token_ids from EmbedsPrompts (qthequartermasterman, Apr 21, 2025)
20668ca  Merge branch 'main' into feature/vllm/add-input-embedding (qthequartermasterman, Apr 28, 2025)
6712ba6  fight mypy to get correct typing for not embeds prompts (qthequartermasterman, Apr 28, 2025)
740b290  remove incorrect call to embeds_inputs (qthequartermasterman, Apr 28, 2025)
8f9bd51  wrestle with mypy and typeddict type narrowing (qthequartermasterman, Apr 29, 2025)
b8d36c6  wrestle with mypy and typeddict type narrowing (qthequartermasterman, Apr 29, 2025)
b764c19  support indexing graph runners that with inputs_embeds (qthequartermasterman, Apr 29, 2025)
0e75db4  feat: completions using embeddings (Nan2018, Oct 28, 2024)
cb6ff22  Merge branch 'main' into feature/vllm/add-input-embedding (qthequartermasterman, May 1, 2025)
85642d0  support encoder decoder models with inputs_embeds (qthequartermasterman, May 1, 2025)
b226fd6  simplify redundant ternary statement (qthequartermasterman, May 1, 2025)
b738d3f  explicitly remove support for inputs embeds with speculative decoding… (qthequartermasterman, May 1, 2025)
2340119  fix occasional device mismatch errors when appending output tokens to… (qthequartermasterman, May 1, 2025)
6a3173a  Merge remote-tracking branch 'andrew/feature/vllm/add-input-embedding… (Nan2018, May 1, 2025)
06215c0  Merge remote-tracking branch 'nan/main' into feature/vllm/input-embed… (Nan2018, May 2, 2025)

149 changes: 148 additions & 1 deletion tests/entrypoints/openai/test_completion.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

# imports for guided decoding tests
import base64
import io
import json
import re
import shutil
@@ -11,10 +13,11 @@
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import torch
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
- from transformers import AutoTokenizer
+ from transformers import AutoConfig, AutoTokenizer

from vllm.transformers_utils.tokenizer import get_tokenizer

@@ -31,6 +34,7 @@
PA_NUM_VIRTUAL_TOKENS = 8

GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
CONFIG = AutoConfig.from_pretrained(MODEL_NAME)


@pytest.fixture(scope="module")
@@ -107,6 +111,14 @@ async def client(server):
yield async_client


def create_dummy_embeds(num_tokens: int = 5) -> str:
"""Create dummy embeddings and return them as base64 encoded string."""
dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size)
buffer = io.BytesIO()
torch.save(dummy_embeds, buffer)
return base64.b64encode(buffer.getvalue()).decode('utf-8')


@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters
@@ -143,6 +155,45 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
assert len(completion.choices[0].text) >= 1
assert completion.choices[0].prompt_logprobs is None

# test using prompt_embeds
encoded_embeds = create_dummy_embeds()
completion = await client.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds})
assert len(completion.choices[0].text) >= 1
assert completion.choices[0].prompt_logprobs is None

# test batch completion with prompt_embeds
encoded_embeds2 = create_dummy_embeds()
completion = await client.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
assert len(completion.choices) == 2
assert len(completion.choices[0].text) >= 1
assert len(completion.choices[1].text) >= 1

# test error case: neither prompt nor prompt_embeds provided
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
max_tokens=5,
temperature=0.0,
)

# test error case: invalid prompt_embeds
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": "invalid_base64"})


@pytest.mark.asyncio
async def test_added_lora_tokens(client: openai.AsyncOpenAI):
@@ -343,6 +394,55 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
assert chunk.choices[0].text
assert "".join(chunks) == single_output

# test streaming with prompt_embeds
encoded_embeds = create_dummy_embeds()
single_completion = await client.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds})
single_output = single_completion.choices[0].text

stream = await client.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
stream=True,
extra_body={"prompt_embeds": encoded_embeds})
chunks = []
finish_reason_count = 0
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == "length"
assert chunk.choices[0].text
assert "".join(chunks) == single_output

# test batch streaming with prompt_embeds
encoded_embeds2 = create_dummy_embeds()
stream = await client.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
stream=True,
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
chunks = [[], []]
finish_reason_count = 0
async for chunk in stream:
chunks[chunk.choices[0].index].append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert finish_reason_count == 2
assert chunk.choices[0].finish_reason == "length"
assert chunk.choices[0].text
assert len(chunks[0]) > 0
assert len(chunks[1]) > 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
@@ -760,6 +860,53 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
1) <= len(top_logprobs) <= logprobs_arg + 1
assert len(logprobs.tokens) > 5

# test using prompt_embeds
encoded_embeds = create_dummy_embeds()
completion = await client.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
echo=True,
logprobs=logprobs_arg,
extra_body={"prompt_embeds": encoded_embeds})

logprobs = completion.choices[0].logprobs
assert logprobs is not None
assert len(logprobs.text_offset) > 5
assert (len(logprobs.token_logprobs) > 5
and logprobs.token_logprobs[0] is None)
assert (len(logprobs.top_logprobs) > 5
and logprobs.top_logprobs[0] is None)
for top_logprobs in logprobs.top_logprobs[1:]:
assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1
assert len(logprobs.tokens) > 5

# test batch completion with prompt_embeds
encoded_embeds2 = create_dummy_embeds()
completion = await client.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
max_tokens=5,
temperature=0.0,
echo=True,
logprobs=logprobs_arg,
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})

assert len(completion.choices) == 2
for choice in completion.choices:
logprobs = choice.logprobs
assert logprobs is not None
assert len(logprobs.text_offset) > 5
assert (len(logprobs.token_logprobs) > 5
and logprobs.token_logprobs[0] is None)
assert (len(logprobs.top_logprobs) > 5
and logprobs.top_logprobs[0] is None)
for top_logprobs in logprobs.top_logprobs[1:]:
assert max(logprobs_arg,
1) <= len(top_logprobs) <= logprobs_arg + 1
assert len(logprobs.tokens) > 5


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
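
For context, here is a minimal client-side sketch of how the new prompt_embeds field can be exercised against a running OpenAI-compatible vLLM server, mirroring the encoding used in the test above (torch.save into a buffer, then base64). The base URL, API key, model name, and hidden size below are illustrative placeholders, not values taken from this PR.

import base64
import io

import torch
from openai import OpenAI

BASE_URL = "http://localhost:8000/v1"  # placeholder: your vLLM server
MODEL = "facebook/opt-125m"            # placeholder: the served model
HIDDEN_SIZE = 768                      # must match the served model's hidden size


def encode_prompt_embeds(embeds: torch.Tensor) -> str:
    """Serialize a (num_tokens, hidden_size) tensor as base64, as the tests do."""
    buffer = io.BytesIO()
    torch.save(embeds, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


client = OpenAI(base_url=BASE_URL, api_key="EMPTY")

# Real embeddings would come from a model's embedding layer or an upstream
# encoder; random values only demonstrate the transport format.
embeds = torch.randn(5, HIDDEN_SIZE)

completion = client.completions.create(
    model=MODEL,
    prompt="",  # the OpenAI client still requires a prompt argument
    max_tokens=5,
    extra_body={"prompt_embeds": encode_prompt_embeds(embeds)},
)
print(completion.choices[0].text)
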
9 changes: 7 additions & 2 deletions vllm/entrypoints/logger.py
@@ -2,6 +2,8 @@

from typing import Optional, Union

import torch

from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
@@ -23,6 +25,7 @@ def log_inputs(
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[list[int]],
prompt_embeds: Optional[torch.Tensor],
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
@@ -39,6 +42,8 @@
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"prompt_embeds shape: %s, "
"lora_request: %s, prompt_adapter_request: %s.", request_id,
prompt, params, prompt_token_ids, lora_request,
prompt_adapter_request)
prompt, params, prompt_token_ids,
prompt_embeds.shape if prompt_embeds is not None else None,
lora_request, prompt_adapter_request)
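
To illustrate the logging change, a short sketch of a call through the updated signature; the RequestLogger construction and the sample values are illustrative assumptions, and only the tensor's shape reaches the log line, never its contents.

import torch

from vllm.entrypoints.logger import RequestLogger
from vllm.sampling_params import SamplingParams

request_logger = RequestLogger(max_log_len=None)
request_logger.log_inputs(
    request_id="cmpl-123",
    prompt=None,                        # no text prompt when embeddings are supplied
    prompt_token_ids=None,
    prompt_embeds=torch.randn(5, 768),  # logged as "prompt_embeds shape: torch.Size([5, 768])"
    params=SamplingParams(max_tokens=5),
    lora_request=None,
    prompt_adapter_request=None,
)
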
13 changes: 11 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -754,8 +754,9 @@ def check_cache_salt_support(cls, data):
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
- model: Optional[str] = None
- prompt: Union[list[int], list[list[int]], str, list[str]]
+ model: str
+ prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
+ prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
@@ -1029,6 +1030,14 @@ def validate_stream_options(cls, data):

return data

@model_validator(mode="before")
@classmethod
def validate_prompt_and_prompt_embeds(cls, data):
if data.get("prompt") is None and data.get("prompt_embeds") is None:
raise ValueError(
"At least one of `prompt` or `prompt_embeds` must be set.")
return data


class EmbeddingCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
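
A rough sketch of what the new validator means for request construction, assuming no other field validation interferes; the model name and payload below are placeholders.

import pydantic

from vllm.entrypoints.openai.protocol import CompletionRequest

# Accepted: prompt_embeds alone satisfies the new validator.
CompletionRequest(model="facebook/opt-125m",
                  prompt_embeds=b"<base64-encoded tensor bytes>",
                  max_tokens=5)

# Rejected: neither prompt nor prompt_embeds is present.
try:
    CompletionRequest(model="facebook/opt-125m", max_tokens=5)
except pydantic.ValidationError as e:
    print(e)  # "At least one of `prompt` or `prompt_embeds` must be set."
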
6 changes: 4 additions & 2 deletions vllm/entrypoints/openai/serving_completion.py
@@ -108,7 +108,7 @@
request_prompts, engine_prompts = await self._preprocess_completion(
request,
tokenizer,
request.prompt,

[GitHub Actions / pre-commit] Check failure on line 111 in vllm/entrypoints/openai/serving_completion.py: Argument 3 to "_preprocess_completion" of "OpenAIServing" has incompatible type "Union[list[int], list[list[int]], str, list[str], None]"; expected "Union[str, list[str], list[int], list[list[int]]]" [arg-type]
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
@@ -131,7 +131,9 @@
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
engine_prompt.get("prompt_token_ids", [])

Check failure on line 134 in vllm/entrypoints/openai/serving_completion.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 1 to "len" has incompatible type "object"; expected "Sized" [arg-type]
or engine_prompt.get("prompt_embeds", []))

if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens, self.default_sampling_params)
@@ -154,13 +156,13 @@

if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt=engine_prompt,

[GitHub Actions / pre-commit] Check failure on line 159 in vllm/entrypoints/openai/serving_completion.py: Argument "prompt" to "beam_search" of "EngineClient" has incompatible type "TypedDict({})"; expected "Union[Union[str, TextPrompt, TokensPrompt, EmbedsPrompt], ExplicitEncoderDecoderPrompt[Union[str, TextPrompt, TokensPrompt, EmbedsPrompt], Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]]]" [arg-type]
request_id=request_id,
params=sampling_params,
)
else:
generator = self.engine_client.generate(
engine_prompt,

[GitHub Actions / pre-commit] Check failure on line 165 in vllm/entrypoints/openai/serving_completion.py: Argument 1 to "generate" of "EngineClient" has incompatible type "TypedDict({})"; expected "Union[Union[str, TextPrompt, TokensPrompt, EmbedsPrompt], ExplicitEncoderDecoderPrompt[Union[str, TextPrompt, TokensPrompt, EmbedsPrompt], Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]]]" [arg-type]
sampling_params,
request_id_item,
lora_request=lora_request,
@@ -211,7 +213,7 @@
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
final_res.prompt = request_prompts[i]["prompt"]
final_res.prompt = request_prompts[i].get("prompt")

[GitHub Actions / pre-commit] Check failure on line 216 in vllm/entrypoints/openai/serving_completion.py: Incompatible types in assignment (expression has type "object", variable has type "Optional[str]") [assignment]

final_res_batch_checked = cast(list[RequestOutput],
final_res_batch)
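
The pre-commit failures flagged above share a cause: engine_prompt is a union of TypedDicts, so bare .get() calls type as object. One possible narrowing helper is sketched below; it is illustrative only, and both the import location of EmbedsPrompt and whether the PR ultimately takes this approach are assumptions.

from typing import Union

import torch

from vllm.inputs import EmbedsPrompt, TokensPrompt  # assumed export location


def prompt_length(engine_prompt: Union[TokensPrompt, EmbedsPrompt]) -> int:
    """Number of prompt positions, whether given as token ids or embeddings."""
    if "prompt_embeds" in engine_prompt:
        embeds: torch.Tensor = engine_prompt["prompt_embeds"]
        return embeds.shape[0]
    return len(engine_prompt["prompt_token_ids"])


# e.g. default_max_tokens = self.max_model_len - prompt_length(engine_prompt)
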