
Commit e729f1d

Add backwards compatibility for #8673
1 parent 6481cf3 commit e729f1d

8 files changed: 256 additions and 88 deletions

tests/async_engine/test_async_llm_engine.py

Lines changed: 5 additions & 3 deletions
@@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
 
 @pytest.mark.asyncio
 async def test_new_requests_event():
+    params = SamplingParams()
+
     engine = MockAsyncLLMEngine()
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
 
-    await engine.add_request("1", "", None)
+    await engine.add_request("1", "", params)
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 1
     assert engine.engine.step_calls == 1
 
-    await engine.add_request("2", "", None)
+    await engine.add_request("2", "", params)
     engine.engine.generate("2")
     await asyncio.sleep(0)
     await asyncio.sleep(0)
@@ -111,7 +113,7 @@ async def test_new_requests_event():
     await asyncio.sleep(0.001)
     assert engine.engine.step_calls == old_step_calls
 
-    await engine.add_request("3", "", None)
+    await engine.add_request("3", "", params)
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1

tests/entrypoints/llm/test_encode.py

Lines changed: 0 additions & 34 deletions
@@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
     assert [o.outputs for o in o1] == [o.outputs for o in o2]
 
 
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize('prompt', PROMPTS)
-def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
-    pooling_params = PoolingParams()
-
-    with pytest.warns(DeprecationWarning, match="'prompts'"):
-        v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)
-
-    v2_output = llm.encode(prompt, pooling_params=pooling_params)
-    assert_outputs_equal(v1_output, v2_output)
-
-    v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
-    assert_outputs_equal(v1_output, v2_output)
-
-
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
 def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
     assert_outputs_equal(v1_output, v2_output)
 
 
-@pytest.mark.skip_global_cleanup
-def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
-    pooling_params = PoolingParams()
-
-    with pytest.warns(DeprecationWarning, match="'prompts'"):
-        v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)
-
-    v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
-    assert_outputs_equal(v1_output, v2_output)
-
-    v2_output = llm.encode(
-        [{
-            "prompt": p
-        } for p in PROMPTS],
-        pooling_params=pooling_params,
-    )
-    assert_outputs_equal(v1_output, v2_output)
-
-
 @pytest.mark.skip_global_cleanup
 def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
     pooling_params = PoolingParams()

tests/entrypoints/llm/test_generate.py

Lines changed: 0 additions & 37 deletions
@@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
     assert [o.outputs for o in o1] == [o.outputs for o in o2]
 
 
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize('prompt', PROMPTS)
-def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
-    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
-
-    with pytest.warns(DeprecationWarning, match="'prompts'"):
-        v1_output = llm.generate(prompts=prompt,
-                                 sampling_params=sampling_params)
-
-    v2_output = llm.generate(prompt, sampling_params=sampling_params)
-    assert_outputs_equal(v1_output, v2_output)
-
-    v2_output = llm.generate({"prompt": prompt},
-                             sampling_params=sampling_params)
-    assert_outputs_equal(v1_output, v2_output)
-
-
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
 def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
     assert_outputs_equal(v1_output, v2_output)
 
 
-@pytest.mark.skip_global_cleanup
-def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
-    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
-
-    with pytest.warns(DeprecationWarning, match="'prompts'"):
-        v1_output = llm.generate(prompts=PROMPTS,
-                                 sampling_params=sampling_params)
-
-    v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
-    assert_outputs_equal(v1_output, v2_output)
-
-    v2_output = llm.generate(
-        [{
-            "prompt": p
-        } for p in PROMPTS],
-        sampling_params=sampling_params,
-    )
-    assert_outputs_equal(v1_output, v2_output)
-
-
 @pytest.mark.skip_global_cleanup
 def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
     sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

vllm/engine/async_llm_engine.py

Lines changed: 81 additions & 5 deletions
@@ -2,8 +2,8 @@
 import time
 import weakref
 from functools import partial
-from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
-                    Mapping, Optional, Set, Tuple, Type, Union)
+from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
+                    List, Mapping, Optional, Set, Tuple, Type, Union, overload)
 from weakref import ReferenceType
 
 import vllm.envs as envs
@@ -28,7 +28,7 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import weak_bind
+from vllm.utils import deprecate_kwargs, weak_bind
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -402,6 +402,21 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
         """Stop the remote worker execution loop."""
         await self.model_executor.stop_remote_worker_execution_loop_async()
 
+    @overload  # DEPRECATED
+    async def add_request_async(
+        self,
+        request_id: str,
+        *,
+        inputs: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> None:
+        ...
+
+    @overload
     async def add_request_async(
         self,
         request_id: str,
@@ -411,8 +426,30 @@ async def add_request_async(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> None:
+        ...
+
+    @deprecate_kwargs(
+        "inputs",
+        additional_message="Please use the 'prompt' parameter instead.",
+    )
+    async def add_request_async(
+        self,
+        request_id: str,
+        prompt: Optional[PromptType] = None,
+        params: Optional[Union[SamplingParams, PoolingParams]] = None,
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        *,
+        inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> None:
         """Async version of :meth:`add_request`."""
+        if inputs is not None:
+            prompt = inputs
+        assert prompt is not None and params is not None
+
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
@@ -774,16 +811,55 @@ async def run_engine_loop(engine_ref: ReferenceType):
 
     # This method does not need to be async, but kept that way
     # for backwards compatibility.
-    async def add_request(
+    @overload  # DEPRECATED
+    def add_request(
+        self,
+        request_id: str,
+        *,
+        inputs: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> Coroutine[None, None, AsyncGenerator[Union[
+            RequestOutput, EmbeddingRequestOutput], None]]:
+        ...
+
+    @overload
+    def add_request(
         self,
         request_id: str,
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> Coroutine[None, None, AsyncGenerator[Union[
+            RequestOutput, EmbeddingRequestOutput], None]]:
+        ...
+
+    @deprecate_kwargs(
+        "inputs",
+        additional_message="Please use the 'prompt' parameter instead.",
+    )
+    async def add_request(
+        self,
+        request_id: str,
+        prompt: Optional[PromptType] = None,
+        params: Optional[Union[SamplingParams, PoolingParams]] = None,
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        *,
+        inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+        if inputs is not None:
+            prompt = inputs
+        assert prompt is not None and params is not None
+
         if not self.is_running:
             if self.start_engine_loop:
                 self.start_background_loop()

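For reference, a minimal usage sketch of the shim above (not part of the commit). Here `engine` is assumed to be an already-constructed `AsyncLLMEngine`, and the `demo` wrapper, request ids and prompt text are purely illustrative:

# Illustration only: assumes `engine` is a running AsyncLLMEngine instance.
from vllm import SamplingParams


async def demo(engine) -> None:
    params = SamplingParams()

    # New spelling: the prompt goes through the 'prompt' parameter.
    stream = await engine.add_request("req-1", "Hello, world", params)

    # Old spelling: 'inputs' survives as a keyword-only argument; the shim
    # copies it into 'prompt', and deprecate_kwargs is expected to emit a
    # DeprecationWarning pointing callers at the 'prompt' parameter instead.
    stream = await engine.add_request("req-2", inputs="Hello, world",
                                      params=params)

    async for output in stream:
        print(output)

Because `inputs` now sits after the bare `*` in the signatures, only callers that spelled `inputs=` explicitly take the deprecation path; positional callers are unaffected.
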
vllm/engine/llm_engine.py

Lines changed: 39 additions & 2 deletions
@@ -6,7 +6,7 @@
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                     Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
-from typing import Set, Type, Union
+from typing import Set, Type, Union, overload
 
 import torch
 from typing_extensions import TypeVar
@@ -51,7 +51,7 @@
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter, Device, weak_bind
+from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -686,6 +686,21 @@ def _add_processed_request(
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()
 
+    @overload  # DEPRECATED
+    def add_request(
+        self,
+        request_id: str,
+        *,
+        inputs: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> None:
+        ...
+
+    @overload
     def add_request(
         self,
         request_id: str,
@@ -695,6 +710,24 @@ def add_request(
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> None:
+        ...
+
+    @deprecate_kwargs(
+        "inputs",
+        additional_message="Please use the 'prompt' parameter instead.",
+    )
+    def add_request(
+        self,
+        request_id: str,
+        prompt: Optional[PromptType] = None,
+        params: Optional[Union[SamplingParams, PoolingParams]] = None,
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        *,
+        inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> None:
         """Add a request to the engine's request pool.
 
@@ -737,6 +770,10 @@ def add_request(
         >>> # continue the request processing
         >>> ...
         """
+        if inputs is not None:
+            prompt = inputs
+        assert prompt is not None and params is not None
+
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")

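Both engine files import `deprecate_kwargs` from `vllm.utils`; its implementation is not part of this diff. As a rough sketch of the pattern it names (an assumption about the helper, not the actual `vllm.utils` code), a decorator of this shape warns whenever one of the listed keywords is passed:

import functools
import warnings
from typing import Callable, Optional


def deprecate_kwargs_sketch(*kw_names: str,
                            additional_message: Optional[str] = None):
    """Warn when any of the named keyword arguments is passed explicitly."""

    def wrapper(fn: Callable) -> Callable:

        @functools.wraps(fn)
        def inner(*args, **kwargs):
            deprecated = [k for k in kw_names if kwargs.get(k) is not None]
            if deprecated:
                msg = (f"The keyword arguments {deprecated} are deprecated "
                       "and may be removed in a future release.")
                if additional_message is not None:
                    msg += f" {additional_message}"
                warnings.warn(DeprecationWarning(msg), stacklevel=2)
            return fn(*args, **kwargs)

        return inner

    return wrapper

Only the decorated implementations carry this runtime check; the `@overload` stubs added above exist so that type checkers accept both the deprecated keyword-only `inputs=` form and the new `prompt` form.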