
Commit 5b950cc

simon-mo authored and sumitd2 committed
Revert "[Core] Rename PromptInputs to PromptType, and inputs to prompt" (vllm-project#8750)
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent 237cfc1 commit 5b950cc

18 files changed: +162 -157 lines

benchmarks/benchmark_latency.py

Lines changed: 4 additions & 4 deletions

@@ -11,7 +11,7 @@

 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
-from vllm.inputs import PromptType
+from vllm.inputs import PromptInputs
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser

@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
     dummy_prompt_token_ids = np.random.randint(10000,
                                                size=(args.batch_size,
                                                      args.input_len))
-    dummy_prompts: List[PromptType] = [{
+    dummy_inputs: List[PromptInputs] = [{
         "prompt_token_ids": batch
     } for batch in dummy_prompt_token_ids.tolist()]

@@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
                 ],
                 on_trace_ready=torch.profiler.tensorboard_trace_handler(
                     str(profile_dir))) as p:
-            llm.generate(dummy_prompts,
+            llm.generate(dummy_inputs,
                          sampling_params=sampling_params,
                          use_tqdm=False)
         print(p.key_averages())
     else:
         start_time = time.perf_counter()
-        llm.generate(dummy_prompts,
+        llm.generate(dummy_inputs,
                      sampling_params=sampling_params,
                      use_tqdm=False)
         end_time = time.perf_counter()
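
For context, the user-facing effect of this revert on the offline API is that prompts passed to LLM.generate are again typed as PromptInputs. A minimal sketch mirroring the benchmark change above; the model name and token IDs are placeholders, not taken from this diff:

from typing import List

from vllm import LLM, SamplingParams
from vllm.inputs import PromptInputs

# Placeholder model; any model supported by vLLM works here.
llm = LLM(model="facebook/opt-125m")

# Token prompts are plain dicts conforming to PromptInputs, as in the benchmark.
dummy_inputs: List[PromptInputs] = [
    {"prompt_token_ids": [1, 2, 3, 4]},
    {"prompt_token_ids": [5, 6, 7, 8]},
]

outputs = llm.generate(dummy_inputs,
                       sampling_params=SamplingParams(max_tokens=16),
                       use_tqdm=False)
for out in outputs:
    print(out.outputs[0].text)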

docs/source/dev/multimodal/multimodal_index.rst

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ Multi-Modality
 vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.

 Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
-via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.
+via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.

 Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
 by following :ref:`this guide <adding_multimodal_plugin>`.

docs/source/dev/offline_inference/llm_inputs.rst

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 LLM Inputs
 ==========

-.. autodata:: vllm.inputs.PromptType
+.. autodata:: vllm.inputs.PromptInputs

 .. autoclass:: vllm.inputs.TextPrompt
     :show-inheritance:

docs/source/models/vlm.rst

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
 We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
 the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.

-To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:
+To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:

 * ``prompt``: The prompt should follow the format that is documented on HuggingFace.
 * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
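
A hedged sketch of the image-prompt usage the restored docs describe, combining the ``prompt`` and ``multi_modal_data`` fields of PromptInputs; the model and prompt template below are illustrative assumptions and should be checked against the model card:

from PIL import Image

from vllm import LLM

# Illustrative vision-language model; the prompt format follows its HuggingFace card.
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
image = Image.open("example.jpg")  # any local RGB image

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is in this image?\nASSISTANT:",
    # The schema is defined by vllm.multimodal.MultiModalDataDict.
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)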

tests/mq_llm_engine/test_error_handling.py

Lines changed: 6 additions & 6 deletions

@@ -61,15 +61,15 @@ async def test_evil_forward(tmp_socket):

     # Throws an error in first forward pass.
     with pytest.raises(RAISED_ERROR):
-        async for _ in client.generate(prompt="Hello my name is",
+        async for _ in client.generate(inputs="Hello my name is",
                                        sampling_params=SamplingParams(),
                                        request_id=uuid.uuid4()):
             pass
     assert client.errored

     # Engine is errored, should get ENGINE_DEAD_ERROR.
     with pytest.raises(MQEngineDeadError):
-        async for _ in client.generate(prompt="Hello my name is",
+        async for _ in client.generate(inputs="Hello my name is",
                                        sampling_params=SamplingParams(),
                                        request_id=uuid.uuid4()):
             pass

@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):

     # Generate call should throw ENGINE_DEAD_ERROR
     with pytest.raises(MQEngineDeadError):
-        async for _ in client.generate(prompt="Hello my name is",
+        async for _ in client.generate(inputs="Hello my name is",
                                        sampling_params=SamplingParams(),
                                        request_id=uuid.uuid4()):
             pass

@@ -165,7 +165,7 @@ async def bad_abort_after_2s():
     # with reference to the original KeyError("foo")
     with pytest.raises(MQEngineDeadError) as execinfo:
         async for _ in client.generate(
-                prompt="Hello my name is",
+                inputs="Hello my name is",
                 sampling_params=SamplingParams(max_tokens=2000),
                 request_id=uuid.uuid4()):
             pass

@@ -190,7 +190,7 @@ async def test_bad_request(tmp_socket):

     # Invalid request should fail, but not crash the server.
     with pytest.raises(ValueError):
-        async for _ in client.generate(prompt="Hello my name is",
+        async for _ in client.generate(inputs="Hello my name is",
                                        sampling_params=SamplingParams(),
                                        request_id="abcd-1",
                                        lora_request=LoRARequest(

@@ -199,7 +199,7 @@ async def test_bad_request(tmp_socket):
             pass

     # This request should be okay.
-    async for _ in client.generate(prompt="Hello my name is",
+    async for _ in client.generate(inputs="Hello my name is",
                                    sampling_params=SamplingParams(),
                                    request_id="abcd-2"):
         pass

tests/mq_llm_engine/utils.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ async def generate(
     count = 0
     async for out in client.generate(
             request_id=request_id,
-            prompt="Hello my name is Robert and",
+            inputs="Hello my name is Robert and",
            sampling_params=SamplingParams(max_tokens=num_tokens,
                                           temperature=0)):
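
For reference, a minimal sketch of the restored keyword as used against the multiprocessing client, mirroring the tests above; it assumes an already-connected MQLLMEngineClient (construction is handled by the test fixtures and is elided here):

import uuid

from vllm import SamplingParams


async def stream_one(client):
    # `client` is a connected MQLLMEngineClient; after the revert the prompt
    # is passed via the `inputs=` keyword.
    final = None
    async for out in client.generate(inputs="Hello my name is Robert and",
                                     sampling_params=SamplingParams(
                                         max_tokens=16, temperature=0),
                                     request_id=str(uuid.uuid4())):
        final = out  # each iteration yields the latest RequestOutput
    return final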

vllm/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -5,7 +5,7 @@
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.llm import LLM
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType, TextPrompt, TokensPrompt
+from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
 from vllm.outputs import (CompletionOutput, EmbeddingOutput,
                           EmbeddingRequestOutput, RequestOutput)

@@ -19,7 +19,7 @@
     "__version_tuple__",
     "LLM",
     "ModelRegistry",
-    "PromptType",
+    "PromptInputs",
     "TextPrompt",
     "TokensPrompt",
     "SamplingParams",

vllm/engine/async_llm_engine.py

Lines changed: 13 additions & 11 deletions

@@ -17,7 +17,7 @@
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType
+from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput

@@ -405,7 +405,7 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
     async def add_request_async(
         self,
         request_id: str,
-        prompt: PromptType,
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,

@@ -420,7 +420,7 @@ async def add_request_async(
             arrival_time = time.time()

         preprocessed_inputs = await self.input_preprocessor.preprocess_async(
-            prompt,
+            inputs,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,

@@ -777,7 +777,7 @@ async def run_engine_loop(engine_ref: ReferenceType):
     async def add_request(
         self,
         request_id: str,
-        prompt: PromptType,
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,

@@ -797,7 +797,7 @@ async def add_request(
         stream = self._request_tracker.add_request(
             request_id,
             verbose=self.log_requests,
-            prompt=prompt,
+            inputs=inputs,
             params=params,
             arrival_time=arrival_time or time.time(),
             lora_request=lora_request,

@@ -808,7 +808,7 @@ async def add_request(

     async def generate(
         self,
-        prompt: PromptType,
+        inputs: PromptInputs,
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,

@@ -822,7 +822,8 @@ async def generate(
         from the LLMEngine to the caller.

         Args:
-            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
                 for more details about the format of each input.
             sampling_params: The sampling parameters of the request.
             request_id: The unique id of the request.

@@ -880,7 +881,7 @@ async def generate(
         """
         async for output in await self.add_request(
                 request_id,
-                prompt,
+                inputs,
                 sampling_params,
                 lora_request=lora_request,
                 trace_headers=trace_headers,

@@ -890,7 +891,7 @@ async def generate(

     async def encode(
         self,
-        prompt: PromptType,
+        inputs: PromptInputs,
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,

@@ -903,7 +904,8 @@ async def encode(
         from the LLMEngine to the caller.

         Args:
-            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
                 for more details about the format of each input.
             pooling_params: The pooling parameters of the request.
             request_id: The unique id of the request.

@@ -957,7 +959,7 @@ async def encode(
         """
         async for output in await self.add_request(
                 request_id,
-                prompt,
+                inputs,
                 pooling_params,
                 lora_request=lora_request,
                 trace_headers=trace_headers,
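
For reference, a minimal sketch of calling AsyncLLMEngine.generate with the restored inputs parameter; the model name and request id are placeholders, not taken from this diff:

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main():
    # Placeholder model name.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    # After the revert, the prompt goes in through `inputs: PromptInputs`.
    async for output in engine.generate(inputs="Hello my name is",
                                        sampling_params=SamplingParams(
                                            max_tokens=16),
                                        request_id="request-0"):
        print(output.outputs[0].text)


asyncio.run(main())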

vllm/engine/llm_engine.py

Lines changed: 5 additions & 4 deletions

@@ -29,7 +29,7 @@
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
-                         InputRegistry, LLMInputs, PromptType)
+                         InputRegistry, LLMInputs, PromptInputs)
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest

@@ -689,7 +689,7 @@ def stop_remote_worker_execution_loop(self) -> None:
     def add_request(
         self,
         request_id: str,
-        prompt: PromptType,
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,

@@ -704,7 +704,8 @@ def add_request(

         Args:
             request_id: The unique ID of the request.
-            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
                 for more details about the format of each input.
             params: Parameters for sampling or pooling.
                 :class:`~vllm.SamplingParams` for text generation.

@@ -744,7 +745,7 @@ def add_request(
             arrival_time = time.time()

         preprocessed_inputs = self.input_preprocessor.preprocess(
-            prompt,
+            inputs,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,

vllm/engine/multiprocessing/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 from typing import List, Mapping, Optional, Union

 from vllm import PoolingParams
-from vllm.inputs import PromptType
+from vllm.inputs import PromptInputs
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest

@@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError):

 @dataclass
 class RPCProcessRequest:
-    prompt: PromptType
+    inputs: PromptInputs
     params: Union[SamplingParams, PoolingParams]
     request_id: str
     lora_request: Optional[LoRARequest] = None

vllm/engine/multiprocessing/client.py

Lines changed: 11 additions & 9 deletions

@@ -25,7 +25,7 @@
                                          RPCStartupResponse)
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
-from vllm.inputs import PromptType
+from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput

@@ -375,7 +375,7 @@ def dead_error(self) -> BaseException:

     def generate(
         self,
-        prompt: PromptType,
+        inputs: PromptInputs,
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,

@@ -389,7 +389,8 @@ def generate(
         from the LLMEngine to the caller.

         Args:
-            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
                 for more details about the format of each input.
             sampling_params: The sampling parameters of the request.
             request_id: The unique id of the request.

@@ -398,13 +399,13 @@ def generate(
             prompt_adapter_request: Prompt Adapter request to use
                 for generation, if any.
         """
-        return self._process_request(prompt, sampling_params, request_id,
+        return self._process_request(inputs, sampling_params, request_id,
                                      lora_request, trace_headers,
                                      prompt_adapter_request)

     def encode(
         self,
-        prompt: PromptType,
+        inputs: PromptInputs,
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,

@@ -417,7 +418,8 @@ def encode(
         from the LLMEngine to the caller.

         Args:
-            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
                 for more details about the format of each input.
             pooling_params: The pooling parameters of the request.
             request_id: The unique id of the request.

@@ -428,12 +430,12 @@ def encode(
             The output `EmbeddingRequestOutput` objects from the LLMEngine
             for the request.
         """
-        return self._process_request(prompt, pooling_params, request_id,
+        return self._process_request(inputs, pooling_params, request_id,
                                      lora_request, trace_headers)

     async def _process_request(
         self,
-        prompt: PromptType,
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
         request_id: str,
         lora_request: Optional[LoRARequest] = None,

@@ -466,7 +468,7 @@ async def _process_request(

         request_bytes = pickle.dumps(
             RPCProcessRequest(
-                prompt=prompt,
+                inputs=inputs,
                 params=params,
                 request_id=request_id,
                 lora_request=lora_request,

vllm/engine/multiprocessing/engine.py

Lines changed: 1 addition & 1 deletion

@@ -252,7 +252,7 @@ def _handle_process_request(self, request: RPCProcessRequest):
         try:
             self.engine.add_request(
                 request_id=request_id,
-                prompt=request.prompt,
+                inputs=request.inputs,
                 params=request.params,
                 lora_request=request.lora_request,
                 trace_headers=request.trace_headers,

0 commit comments
