
Commit cc2a77d

qthequartermasterman, 临景, Bryce1010, Nan2018, and DarkLight1337 authored May 2, 2025
[Core] [Bugfix] Add Input Embeddings (#15428)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 9e2de9b commit cc2a77d

File tree

22 files changed: +691 −113 lines
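
For orientation, here is a hedged end-to-end sketch of what this commit's tests exercise: generating from precomputed prompt embeddings instead of token ids. The entry point and the VLLM_USE_V1=0 gate are inferred from the test changes below, and the model name is a placeholder, so treat this as an illustrative sketch rather than the commit's own example.

import os

os.environ["VLLM_USE_V1"] = "0"  # assumption: the tests gate the embeddings path on V0

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "facebook/opt-125m"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)

# Embed the prompt with the HF model's input-embedding layer,
# the same way the updated test_common.py does.
token_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
with torch.no_grad():
    prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

# Hand the embeddings to vLLM via the new "prompt_embeds" prompt field
# (assumed public shape, mirroring the TextPrompt kwargs built in tests/conftest.py).
llm = LLM(model=model_name)
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)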
 

tests/conftest.py

Lines changed: 10 additions & 8 deletions

@@ -787,7 +787,7 @@ def __init__(

     def get_inputs(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[torch.Tensor]],
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
@@ -809,16 +809,18 @@ def get_inputs(
             if audios is not None and (audio := audios[i]) is not None:
                 multi_modal_data["audio"] = audio

-            inputs.append(
-                TextPrompt(prompt=prompt,
-                           multi_modal_data=multi_modal_data
-                           if multi_modal_data else None))
+            text_prompt_kwargs = {
+                ("prompt" if isinstance(prompt, str) else "prompt_embeds"):
+                prompt,
+                "multi_modal_data": multi_modal_data or None
+            }
+            inputs.append(TextPrompt(**text_prompt_kwargs))

         return inputs

     def generate(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[torch.Tensor]],
         sampling_params: SamplingParams,
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
@@ -844,7 +846,7 @@ def generate(
                 output_str = sample.text
                 output_ids = list(sample.token_ids)
                 req_sample_output_ids.append(prompt_ids + output_ids)
-                req_sample_output_strs.append(prompt_str + output_str)
+                req_sample_output_strs.append((prompt_str or "") + output_str)
             outputs.append((req_sample_output_ids, req_sample_output_strs))
         return outputs

@@ -911,7 +913,7 @@ def generate_encoder_decoder_w_logprobs(

     def generate_greedy(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[torch.Tensor]],
         max_tokens: int,
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
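
Aside: the key-selection idiom in get_inputs above can be read in isolation. A minimal sketch of the same routing, assuming only that TextPrompt accepts either a "prompt" string or a "prompt_embeds" tensor as the diff shows; the helper name below is hypothetical.

from typing import Union

import torch


def build_prompt_kwargs(prompt: Union[str, torch.Tensor]) -> dict:
    # Route a string to the "prompt" key and a tensor to "prompt_embeds",
    # mirroring the text_prompt_kwargs construction in get_inputs.
    key = "prompt" if isinstance(prompt, str) else "prompt_embeds"
    return {key: prompt, "multi_modal_data": None}


print(build_prompt_kwargs("Hello, world").keys())
print(build_prompt_kwargs(torch.rand(4, 8)).keys())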

tests/core/test_scheduler.py

Lines changed: 73 additions & 1 deletion

@@ -2,16 +2,18 @@

 import time
 from collections import deque
+from typing import Optional
 from unittest.mock import MagicMock

 import pytest  # noqa
+import torch
 from torch import Use  # noqa

 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.core.interfaces import AllocStatus
 from vllm.core.scheduler import Scheduler, SchedulingBudget
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SequenceGroup
+from vllm.sequence import SequenceGroup, SequenceStatus

 from .utils import (append_new_token, append_new_token_seq,
                     append_new_token_seq_group, create_dummy_prompt,
@@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
     ), "A partial prefix of C (4 tokens) should be prefilled, with the "
     "remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
     "then be rounded down to 2 tokens on block size, thus 6 tokens in total."
+
+
+def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
+    """
+    Test that the scheduler does not schedule batches with prompt tokens and
+    prompt embeddings co-mingled.
+    """
+    block_size = 2
+    max_seq_group = 3
+    scheduler = initialize_scheduler(
+        block_size=block_size,
+        num_cpu_blocks=16,
+        num_gpu_blocks=16,
+        max_num_seqs=max_seq_group,
+        max_model_len=100,
+        enable_prefix_caching=True,
+    )
+
+    # the odd indexed inputs should be passed in via embeddings,
+    # evens via token_ids
+    seq_length = 7
+    embedding_size = 5
+    num_seqs = 11
+    seq_tokens: list[list[int]] = []
+    seq_embeds: list[Optional[torch.Tensor]] = []
+    for i in range(num_seqs):
+        if i % 2:
+            seq_tokens.append(list(range(seq_length)))
+            seq_embeds.append(None)
+        else:
+            seq_tokens.append([0] * seq_length)
+            seq_embeds.append(torch.rand(embedding_size))
+
+    seq_and_seq_groups = [
+        create_dummy_prompt(f"{i}",
+                            prompt_tokens=seq_tokens[i],
+                            prompt_embeds=seq_embeds[i],
+                            block_size=block_size)
+        for i in range(len(seq_tokens))
+    ]
+
+    for _, seq_group in seq_and_seq_groups:
+        scheduler.add_seq_group(seq_group)
+
+    while not all(seq.is_finished() for seq, _ in seq_and_seq_groups):
+        unfinished_seq_groups = [
+            seq_group for _, seq_group in seq_and_seq_groups
+            if not seq_group.is_finished()
+        ]
+        _, out = schedule_and_update_computed_tokens(scheduler)
+        assert len(out.scheduled_seq_groups) > 0
+        batch_is_prompt_embeds = out.scheduled_seq_groups[
+            0].seq_group.uses_prompt_embeds()
+        expected_scheduled_seq_groups = [
+            seq_group for seq_group in unfinished_seq_groups
+            if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
+        ]
+
+        # We should have as many scheduled groups as possible, without mixing
+        assert len(out.scheduled_seq_groups) == min(
+            max_seq_group, len(expected_scheduled_seq_groups))
+        assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() ==
+                   batch_is_prompt_embeds
+                   for scheduled_seq_group in out.scheduled_seq_groups)

+        # Finish the scheduled groups
+        for scheduled_seq_group in out.scheduled_seq_groups:
+            for seq in scheduled_seq_group.seq_group.seqs:
+                seq.status = SequenceStatus.FINISHED_STOPPED
+        scheduler.free_finished_seq_groups()
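
The invariant this test encodes, that a scheduled batch never mixes token-id prompts with embedding prompts, can be stated as a tiny standalone check. A minimal sketch, assuming only that each scheduled group exposes uses_prompt_embeds() as the test does; the FakeSeqGroup stand-in below is hypothetical.

from dataclasses import dataclass


@dataclass
class FakeSeqGroup:
    # Hypothetical stand-in for a SequenceGroup; only the one method the
    # check needs is modelled.
    uses_embeds: bool

    def uses_prompt_embeds(self) -> bool:
        return self.uses_embeds


def batch_is_homogeneous(groups: list[FakeSeqGroup]) -> bool:
    # True when every group in the batch agrees on uses_prompt_embeds().
    if not groups:
        return True
    first = groups[0].uses_prompt_embeds()
    return all(g.uses_prompt_embeds() == first for g in groups)


assert batch_is_homogeneous([FakeSeqGroup(True), FakeSeqGroup(True)])
assert not batch_is_homogeneous([FakeSeqGroup(True), FakeSeqGroup(False)])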

tests/core/utils.py

Lines changed: 9 additions & 2 deletions

@@ -5,9 +5,11 @@
 from collections.abc import Sequence as GenericSequence
 from typing import Any, Optional

+import torch
+
 from vllm import SamplingParams
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
-from vllm.inputs import EncoderDecoderInputs, token_inputs
+from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs
 from vllm.lora.request import LoRARequest
 from vllm.sequence import (Logprob, Sequence, SequenceGroup,
                            SequenceGroupMetadata)
@@ -19,6 +21,7 @@ def create_dummy_prompt(
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
     prompt_tokens: Optional[list[int]] = None,
+    prompt_embeds: Optional[torch.Tensor] = None,
     min_tokens: int = 0,
     max_tokens: int = 16,
 ) -> tuple[Sequence, SequenceGroup]:
@@ -31,9 +34,13 @@ def create_dummy_prompt(
         prompt_tokens = list(range(prompt_length))

     prompt_str = " ".join([str(t) for t in prompt_tokens])
+    inputs = token_inputs(
+        prompt_token_ids=prompt_tokens,
+        prompt=prompt_str) if prompt_embeds is None else embeds_inputs(
+            prompt_embeds=prompt_embeds)
     prompt = Sequence(
         int(request_id),
-        inputs=token_inputs(prompt_tokens, prompt=prompt_str),
+        inputs=inputs,
         block_size=block_size,
     )
     seq_group = SequenceGroup(
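
The inline conditional that picks between token_inputs and embeds_inputs reads more plainly as an if/else. A self-contained restatement of the same selection, assuming vllm.inputs exposes token_inputs and embeds_inputs with the keyword arguments the diff already uses; the build_inputs helper name is hypothetical.

from typing import Optional

import torch

from vllm.inputs import embeds_inputs, token_inputs


def build_inputs(prompt_tokens: list[int],
                 prompt_str: str,
                 prompt_embeds: Optional[torch.Tensor] = None):
    # Mirrors create_dummy_prompt: token inputs by default, embedding inputs
    # when a tensor is supplied.
    if prompt_embeds is None:
        return token_inputs(prompt_token_ids=prompt_tokens, prompt=prompt_str)
    return embeds_inputs(prompt_embeds=prompt_embeds)


token_backed = build_inputs([1, 2, 3], "1 2 3")
embed_backed = build_inputs([0, 0, 0], "0 0 0", prompt_embeds=torch.rand(3, 8))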

tests/models/language/generation/test_common.py

Lines changed: 26 additions & 0 deletions

@@ -1,4 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
+import os
+from typing import Optional
+
 import pytest
 import torch

@@ -110,6 +113,18 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)

+        prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
+            "VLLM_USE_V1") == "0" else None
+        prompt_token_ids = []
+        for prompt in example_prompts:
+            token_ids = hf_model.tokenizer(prompt,
+                                           return_tensors="pt").input_ids.to(
+                                               hf_model.model.device)
+            prompt_token_ids.append(token_ids)
+            if prompt_embeds is not None:
+                prompt_embeds.append(hf_model.model.get_input_embeddings()(
+                    token_ids).squeeze(0))
+
     with vllm_runner(
         model,
         tokenizer_name=model_info.tokenizer or model,
@@ -119,13 +134,24 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
+        if prompt_embeds is not None:
+            vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs(
+                prompt_embeds, max_tokens, num_logprobs)

     check_logprobs_close(
         outputs_0_lst=hf_outputs,
         outputs_1_lst=vllm_outputs,
         name_0="hf",
         name_1="vllm",
     )
+    if prompt_embeds is not None:
+        check_logprobs_close(
+            outputs_0_lst=vllm_outputs,
+            outputs_1_lst=vllm_outputs_from_embeds,
+            name_0="vllm",
+            name_1="vllm_from_embeds",
+        )
+
     if use_rocm_aiter:
         # this is to ensure that vllm engine
         # has deallocated the memory before running the next
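
The embeddings this test feeds back into vLLM follow a simple shape contract: the tokenizer returns (1, seq_len) token ids, the input-embedding layer maps them to (1, seq_len, hidden_size), and squeeze(0) leaves the (seq_len, hidden_size) tensor used as one prompt's embeddings. A minimal sketch of that contract using a toy nn.Embedding so it runs without downloading a model:

import torch

vocab_size, hidden_size = 100, 16
embed = torch.nn.Embedding(vocab_size, hidden_size)

token_ids = torch.tensor([[5, 17, 42, 7]])   # (1, seq_len), as a tokenizer returns
prompt_embeds = embed(token_ids).squeeze(0)  # (seq_len, hidden_size)

assert token_ids.shape == (1, 4)
assert prompt_embeds.shape == (4, hidden_size)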

0 commit comments
