
Commit 38d780f

simon-mo authored and Alvant committed
Revert "rename PromptInputs and inputs with backward compatibility (vllm-project#8760)" (vllm-project#8810)
1 parent 1884801 commit 38d780f

21 files changed (+245 -438 lines)


benchmarks/benchmark_latency.py

Lines changed: 4 additions & 4 deletions

@@ -11,7 +11,7 @@
 
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
-from vllm.inputs import PromptType
+from vllm.inputs import PromptInputs
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser
 
@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
     dummy_prompt_token_ids = np.random.randint(10000,
                                                size=(args.batch_size,
                                                      args.input_len))
-    dummy_prompts: List[PromptType] = [{
+    dummy_inputs: List[PromptInputs] = [{
         "prompt_token_ids": batch
     } for batch in dummy_prompt_token_ids.tolist()]
 
@@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
                 ],
                 on_trace_ready=torch.profiler.tensorboard_trace_handler(
                     str(profile_dir))) as p:
-            llm.generate(dummy_prompts,
+            llm.generate(dummy_inputs,
                          sampling_params=sampling_params,
                          use_tqdm=False)
         print(p.key_averages())
     else:
         start_time = time.perf_counter()
-        llm.generate(dummy_prompts,
+        llm.generate(dummy_inputs,
                      sampling_params=sampling_params,
                      use_tqdm=False)
         end_time = time.perf_counter()
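
The hunks above revert the benchmark back to building a List[PromptInputs] of token-id dicts and passing it straight to llm.generate. A minimal stand-alone sketch of that call pattern under the reverted API follows; the model name, batch size, and input length are illustrative assumptions, not values from the benchmark script:

from typing import List

import numpy as np

from vllm import LLM, SamplingParams
from vllm.inputs import PromptInputs  # reverted name; was PromptType before this commit

# Illustrative assumptions: any small model and prompt shape work for the sketch.
llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

# Random token ids stand in for real prompts, as in the benchmark script.
dummy_prompt_token_ids = np.random.randint(10000, size=(8, 32))
dummy_inputs: List[PromptInputs] = [{
    "prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

outputs = llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False)
print(f"{len(outputs)} completions generated")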

docs/source/dev/multimodal/multimodal_index.rst

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ Multi-Modality
 vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
 
 Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
-via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.
+via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
 
 Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
 by following :ref:`this guide <adding_multimodal_plugin>`.

docs/source/dev/offline_inference/llm_inputs.rst

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 LLM Inputs
 ==========
 
-.. autodata:: vllm.inputs.PromptType
+.. autodata:: vllm.inputs.PromptInputs
 
 .. autoclass:: vllm.inputs.TextPrompt
     :show-inheritance:

docs/source/models/vlm.rst

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
 We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
 the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
 
-To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:
+To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
 
 * ``prompt``: The prompt should follow the format that is documented on HuggingFace.
 * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
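
The restored doc text above describes the two fields a vision-language request uses: prompt (a model-specific template) and multi_modal_data (schema from vllm.multimodal.MultiModalDataDict). A hedged sketch of that request shape follows; the model name, image path, and prompt template are illustrative assumptions, so check the model card for the exact format:

from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # assumption: any supported VLM works here
image = Image.open("example.jpg")            # assumption: a local image file

outputs = llm.generate(
    {
        # Prompt template is model-specific; this one is an assumption.
        "prompt": "USER: <image>\nWhat is shown in this image?\nASSISTANT:",
        # The image rides along via multi_modal_data, as the doc describes.
        "multi_modal_data": {"image": image},
    },
    sampling_params=SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)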

tests/async_engine/test_async_llm_engine.py

Lines changed: 3 additions & 5 deletions

@@ -86,19 +86,17 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
 
 @pytest.mark.asyncio
 async def test_new_requests_event():
-    params = SamplingParams()
-
     engine = MockAsyncLLMEngine()
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
 
-    await engine.add_request("1", "", params)
+    await engine.add_request("1", "", None)
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 1
     assert engine.engine.step_calls == 1
 
-    await engine.add_request("2", "", params)
+    await engine.add_request("2", "", None)
     engine.engine.generate("2")
     await asyncio.sleep(0)
     await asyncio.sleep(0)
@@ -113,7 +111,7 @@ async def test_new_requests_event():
     await asyncio.sleep(0.001)
     assert engine.engine.step_calls == old_step_calls
 
-    await engine.add_request("3", "", params)
+    await engine.add_request("3", "", None)
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1

tests/entrypoints/llm/test_encode.py

Lines changed: 34 additions & 0 deletions

@@ -49,6 +49,21 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
     assert [o.outputs for o in o1] == [o.outputs for o in o2]
 
 
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize('prompt', PROMPTS)
+def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
+    pooling_params = PoolingParams()
+
+    with pytest.warns(DeprecationWarning, match="'prompts'"):
+        v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)
+
+    v2_output = llm.encode(prompt, pooling_params=pooling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+    v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
 def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@@ -64,6 +79,25 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
     assert_outputs_equal(v1_output, v2_output)
 
 
+@pytest.mark.skip_global_cleanup
+def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
+    pooling_params = PoolingParams()
+
+    with pytest.warns(DeprecationWarning, match="'prompts'"):
+        v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)
+
+    v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+    v2_output = llm.encode(
+        [{
+            "prompt": p
+        } for p in PROMPTS],
+        pooling_params=pooling_params,
+    )
+    assert_outputs_equal(v1_output, v2_output)
+
+
 @pytest.mark.skip_global_cleanup
 def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
     pooling_params = PoolingParams()
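
The tests restored above pin the backward-compatible surface of LLM.encode: the legacy prompts= keyword still works but warns, while the current form takes a positional string, a list, or a {"prompt": ...} dict. A hedged sketch of what such a call returns; the model name is an illustrative assumption and must be an embedding model:

from vllm import LLM, PoolingParams

llm = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)  # assumption
pooling_params = PoolingParams()

# Same request expressed as a prompt dict (the "v2" form in the tests above).
outputs = llm.encode({"prompt": "Hello, my name is"}, pooling_params=pooling_params)

# Each item is an EmbeddingRequestOutput; outputs.embedding holds the vector.
embedding = outputs[0].outputs.embedding
print(len(embedding))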

tests/entrypoints/llm/test_generate.py

Lines changed: 37 additions & 0 deletions

@@ -47,6 +47,23 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
     assert [o.outputs for o in o1] == [o.outputs for o in o2]
 
 
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize('prompt', PROMPTS)
+def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
+    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
+
+    with pytest.warns(DeprecationWarning, match="'prompts'"):
+        v1_output = llm.generate(prompts=prompt,
+                                 sampling_params=sampling_params)
+
+    v2_output = llm.generate(prompt, sampling_params=sampling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+    v2_output = llm.generate({"prompt": prompt},
+                             sampling_params=sampling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
 def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@@ -62,6 +79,26 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
     assert_outputs_equal(v1_output, v2_output)
 
 
+@pytest.mark.skip_global_cleanup
+def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
+    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
+
+    with pytest.warns(DeprecationWarning, match="'prompts'"):
+        v1_output = llm.generate(prompts=PROMPTS,
+                                 sampling_params=sampling_params)
+
+    v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
+    assert_outputs_equal(v1_output, v2_output)
+
+    v2_output = llm.generate(
+        [{
+            "prompt": p
+        } for p in PROMPTS],
+        sampling_params=sampling_params,
+    )
+    assert_outputs_equal(v1_output, v2_output)
+
+
 @pytest.mark.skip_global_cleanup
 def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
     sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
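
These tests also assert that the legacy prompts= keyword raises a DeprecationWarning while remaining functional. vLLM has its own keyword-deprecation helper for this; the decorator below is a simplified, self-contained sketch of the idea only, not the library's actual implementation, and every name in it is illustrative:

import functools
import warnings


def deprecate_kwarg(old_name: str, new_name: str):
    """Accept old_name= as an alias for new_name=, emitting a DeprecationWarning."""

    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                warnings.warn(
                    f"The keyword argument '{old_name}' is deprecated; "
                    f"use '{new_name}' instead.",
                    DeprecationWarning,
                    stacklevel=2)
                kwargs.setdefault(new_name, kwargs.pop(old_name))
            return func(*args, **kwargs)

        return wrapper

    return decorator


class FakeLLM:
    """Stand-in object; only here to show the decorator in use."""

    @deprecate_kwarg("prompts", "inputs")
    def generate(self, inputs, sampling_params=None):
        prompts = [inputs] if isinstance(inputs, str) else list(inputs)
        return [f"generated for: {p}" for p in prompts]


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    FakeLLM().generate(prompts="Hello my name is")
assert any(issubclass(w.category, DeprecationWarning) for w in caught)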

tests/mq_llm_engine/test_error_handling.py

Lines changed: 6 additions & 6 deletions

@@ -61,15 +61,15 @@ async def test_evil_forward(tmp_socket):
 
         # Throws an error in first forward pass.
         with pytest.raises(RAISED_ERROR):
-            async for _ in client.generate(prompt="Hello my name is",
+            async for _ in client.generate(inputs="Hello my name is",
                                            sampling_params=SamplingParams(),
                                            request_id=uuid.uuid4()):
                 pass
         assert client.errored
 
         # Engine is errored, should get ENGINE_DEAD_ERROR.
         with pytest.raises(MQEngineDeadError):
-            async for _ in client.generate(prompt="Hello my name is",
+            async for _ in client.generate(inputs="Hello my name is",
                                            sampling_params=SamplingParams(),
                                            request_id=uuid.uuid4()):
                 pass
@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
 
         # Generate call should throw ENGINE_DEAD_ERROR
         with pytest.raises(MQEngineDeadError):
-            async for _ in client.generate(prompt="Hello my name is",
+            async for _ in client.generate(inputs="Hello my name is",
                                            sampling_params=SamplingParams(),
                                            request_id=uuid.uuid4()):
                 pass
@@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket):
         # with reference to the original KeyError("foo")
         with pytest.raises(MQEngineDeadError) as execinfo:
             async for _ in client.generate(
-                    prompt="Hello my name is",
+                    inputs="Hello my name is",
                     sampling_params=SamplingParams(max_tokens=10),
                     request_id=uuid.uuid4()):
                 pass
@@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket):
 
         # Invalid request should fail, but not crash the server.
         with pytest.raises(ValueError):
-            async for _ in client.generate(prompt="Hello my name is",
+            async for _ in client.generate(inputs="Hello my name is",
                                            sampling_params=SamplingParams(),
                                            request_id="abcd-1",
                                            lora_request=LoRARequest(
@@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket):
                 pass
 
         # This request should be okay.
-        async for _ in client.generate(prompt="Hello my name is",
+        async for _ in client.generate(inputs="Hello my name is",
                                        sampling_params=SamplingParams(),
                                        request_id="abcd-2"):
             pass
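
Every call site above switches the streaming client from prompt= back to inputs=; the call returns an async generator that must be iterated even when only the final result (or error) matters. A hedged sketch of consuming such a stream outside pytest; how client is obtained is deliberately elided, and the wrapped-error handling mirrors, rather than reproduces, the tests' expectations:

import uuid

from vllm import SamplingParams


async def consume(client) -> str:
    """Drain one streaming request and return the final generated text."""
    text = ""
    try:
        async for out in client.generate(inputs="Hello my name is",
                                         sampling_params=SamplingParams(max_tokens=10),
                                         request_id=str(uuid.uuid4())):
            # Each item is a cumulative RequestOutput for this request.
            text = out.outputs[0].text
    except Exception as err:  # e.g. MQEngineDeadError once the engine has died
        raise RuntimeError("engine stream failed") from err
    return text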

tests/mq_llm_engine/utils.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ async def generate(
     count = 0
     async for out in client.generate(
             request_id=request_id,
-            prompt="Hello my name is Robert and",
+            inputs="Hello my name is Robert and",
             sampling_params=SamplingParams(max_tokens=num_tokens,
                                            temperature=0)):
 
vllm/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -5,7 +5,7 @@
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.llm import LLM
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType, TextPrompt, TokensPrompt
+from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
 from vllm.outputs import (CompletionOutput, EmbeddingOutput,
                           EmbeddingRequestOutput, RequestOutput)
@@ -19,7 +19,7 @@
     "__version_tuple__",
     "LLM",
     "ModelRegistry",
-    "PromptType",
+    "PromptInputs",
     "TextPrompt",
     "TokensPrompt",
     "SamplingParams",

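The revert also restores PromptInputs in the package's top-level exports, so all three prompt forms are importable directly from vllm. A small sketch of that surface; no engine is constructed here, it only exercises the imports and the TypedDict-style prompt constructors:

from typing import List

from vllm import PromptInputs, TextPrompt, TokensPrompt

prompts: List[PromptInputs] = [
    "Hello, my name is",                            # bare string prompt
    TextPrompt(prompt="The capital of France is"),  # text prompt dict
    TokensPrompt(prompt_token_ids=[1, 2, 3, 4]),    # pre-tokenized prompt dict
]
print(prompts)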