
Commit 5d8cec2

DarkLight1337 authored and LeiWang1999 committed
[Model] Explicit interface for vLLM models and support OOT embedding models (vllm-project#9108)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 152e363 commit 5d8cec2

File tree

10 files changed: +342 -37 lines changed

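In short, this commit (a) introduces an explicit base interface for vLLM models, exposing VllmModelForTextGeneration and VllmModelForEmbedding together with the is_text_generation_model and is_embedding_model predicates, and (b) extends the out-of-tree (OOT) plugin machinery and tests so that embedding models can be registered from outside the vLLM tree. A minimal sketch of the resulting public surface, pieced together from the diffs below (running it outside vLLM's own test environment is an assumption):

from vllm.model_executor.models import (ModelRegistry, is_embedding_model,
                                        is_text_generation_model)

# Resolve a registered architecture name to its model class, then query
# the explicit interface predicates added by this commit.
model_cls, _ = ModelRegistry.resolve_model_cls("LlamaForCausalLM")
assert is_text_generation_model(model_cls)
assert not is_embedding_model(model_cls)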

tests/conftest.py

Lines changed: 20 additions & 0 deletions
@@ -871,6 +871,7 @@ def num_gpus_available():
 temp_dir = tempfile.gettempdir()
 _dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
 _dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
+_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")
 
 
 @pytest.fixture
@@ -909,3 +910,22 @@ def dummy_llava_path():
     with open(json_path, "w") as f:
         json.dump(config, f)
     return _dummy_llava_path
+
+
+@pytest.fixture
+def dummy_gemma2_embedding_path():
+    json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
+    if not os.path.exists(_dummy_gemma2_embedding_path):
+        snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
+                          local_dir=_dummy_gemma2_embedding_path,
+                          ignore_patterns=[
+                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
+                              "*.msgpack"
+                          ])
+    assert os.path.exists(json_path)
+    with open(json_path, "r") as f:
+        config = json.load(f)
+    config["architectures"] = ["MyGemma2Embedding"]
+    with open(json_path, "w") as f:
+        json.dump(config, f)
+    return _dummy_gemma2_embedding_path

tests/models/test_oot_registration.py

Lines changed: 15 additions & 3 deletions
@@ -2,7 +2,7 @@
 
 import pytest
 
-from vllm import LLM, SamplingParams
+from vllm import LLM, PoolingParams, SamplingParams
 from vllm.assets.image import ImageAsset
 
 from ..utils import fork_new_process_for_each_test
@@ -17,7 +17,7 @@ def test_plugin(dummy_opt_path):
 
 
 @fork_new_process_for_each_test
-def test_oot_registration(dummy_opt_path):
+def test_oot_registration_text_generation(dummy_opt_path):
     os.environ["VLLM_PLUGINS"] = "register_dummy_model"
     prompts = ["Hello, my name is", "The text does not matter"]
     sampling_params = SamplingParams(temperature=0)
@@ -32,11 +32,23 @@ def test_oot_registration(dummy_opt_path):
     assert rest == ""
 
 
+@fork_new_process_for_each_test
+def test_oot_registration_embedding(dummy_gemma2_embedding_path):
+    os.environ["VLLM_PLUGINS"] = "register_dummy_model"
+    prompts = ["Hello, my name is", "The text does not matter"]
+    sampling_params = PoolingParams()
+    llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
+    outputs = llm.encode(prompts, sampling_params)
+
+    for output in outputs:
+        assert all(v == 0 for v in output.outputs.embedding)
+
+
 image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
 
 
 @fork_new_process_for_each_test
-def test_oot_multimodal_registration(dummy_llava_path):
+def test_oot_registration_multimodal(dummy_llava_path):
     os.environ["VLLM_PLUGINS"] = "register_dummy_model"
     prompts = [{
         "prompt": "What's in the image?<image>",

tests/models/test_registry.py

Lines changed: 22 additions & 2 deletions
@@ -3,7 +3,14 @@
 import pytest
 import torch.cuda
 
-from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.models import (is_embedding_model,
+                                        is_text_generation_model,
+                                        supports_multimodal)
+from vllm.model_executor.models.registry import (_EMBEDDING_MODELS,
+                                                 _MULTIMODAL_MODELS,
+                                                 _SPECULATIVE_DECODING_MODELS,
+                                                 _TEXT_GENERATION_MODELS,
+                                                 ModelRegistry)
 from vllm.platforms import current_platform
 
 from ..utils import fork_new_process_for_each_test
@@ -12,7 +19,20 @@
 @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
 def test_registry_imports(model_arch):
     # Ensure all model classes can be imported successfully
-    ModelRegistry.resolve_model_cls(model_arch)
+    model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)
+
+    if model_arch in _SPECULATIVE_DECODING_MODELS:
+        pass  # Ignore these models which do not have a unified format
+    else:
+        assert is_text_generation_model(model_cls) is (
+            model_arch in _TEXT_GENERATION_MODELS
+            or model_arch in _MULTIMODAL_MODELS)
+
+        assert is_embedding_model(model_cls) is (model_arch
+                                                 in _EMBEDDING_MODELS)
+
+        assert supports_multimodal(model_cls) is (model_arch
+                                                  in _MULTIMODAL_MODELS)
 
 
 @fork_new_process_for_each_test
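The assertions above pin down the invariant the new predicates are meant to satisfy: outside the speculative-decoding models, is_text_generation_model, is_embedding_model, and supports_multimodal classify a resolved class exactly according to the registry tables. As a hedged illustration (not part of the commit), the same predicates can enumerate embedding-capable architectures through the public API:

from vllm.model_executor.models import ModelRegistry, is_embedding_model

# Resolving every arch imports every model class, exactly as
# test_registry_imports does; here we reuse that to list the
# architectures whose classes implement the embedding interface.
embedding_archs = [
    arch for arch in ModelRegistry.get_supported_archs()
    if is_embedding_model(ModelRegistry.resolve_model_cls(arch)[0])
]
print(embedding_archs)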

tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -9,6 +9,12 @@ def register():
     ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
 
     # Test passing lazy model
+    if "MyGemma2Embedding" not in ModelRegistry.get_supported_archs():
+        ModelRegistry.register_model(
+            "MyGemma2Embedding",
+            "vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding",
+        )
+
     if "MyLlava" not in ModelRegistry.get_supported_archs():
         ModelRegistry.register_model("MyLlava",
                                      "vllm_add_dummy_model.my_llava:MyLlava")
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+from typing import List, Optional, Union
+
+import torch
+
+from vllm.attention import AttentionMetadata
+from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel
+from vllm.sequence import IntermediateTensors
+
+
+class MyGemma2Embedding(Gemma2EmbeddingModel):
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = super().forward(
+            input_ids,
+            positions,
+            kv_caches,
+            attn_metadata,
+            intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
+
+        if isinstance(hidden_states, IntermediateTensors):
+            return hidden_states
+
+        # Return all-zero embeddings
+        return torch.zeros_like(hidden_states)
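Note that register() in the plugin's __init__.py can pass either a class object (as with MyOPTForCausalLM) or a lazy "module:ClassName" string, as it does for MyGemma2Embedding, which defers importing this module until the architecture is actually resolved. For vLLM to discover the plugin at all, the package must expose register() via an entry point; the following setup.py sketch shows the general shape (the vllm.general_plugins group name comes from vLLM's plugin loader, but the exact packaging shown here is an assumption, not part of this diff):

from setuptools import setup

setup(
    name="vllm_add_dummy_model",
    version="0.1",
    packages=["vllm_add_dummy_model"],
    # vLLM's plugin loader scans this entry-point group; the entry-point
    # name is what VLLM_PLUGINS selects in the tests above.
    entry_points={
        "vllm.general_plugins":
        ["register_dummy_model = vllm_add_dummy_model:register"]
    },
)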

vllm/model_executor/models/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -1,10 +1,17 @@
 from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
                          SupportsPP, has_inner_state, supports_lora,
                          supports_multimodal, supports_pp)
+from .interfaces_base import (VllmModelForEmbedding,
+                              VllmModelForTextGeneration, is_embedding_model,
+                              is_text_generation_model)
 from .registry import ModelRegistry
 
 __all__ = [
     "ModelRegistry",
+    "VllmModelForEmbedding",
+    "is_embedding_model",
+    "VllmModelForTextGeneration",
+    "is_text_generation_model",
     "HasInnerState",
     "has_inner_state",
     "SupportsLoRA",

vllm/model_executor/models/interfaces.py

Lines changed: 7 additions & 21 deletions
@@ -1,14 +1,13 @@
-import inspect
 from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
                     Protocol, Type, Union, overload, runtime_checkable)
 
 import torch
 from typing_extensions import TypeIs
 
 from vllm.logger import init_logger
+from vllm.utils import supports_kw
 
 if TYPE_CHECKING:
-    from vllm.attention import AttentionMetadata
     from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
     from vllm.sequence import IntermediateTensors
 
@@ -142,9 +141,7 @@ def supports_lora(
     return result
 
 
-def _supports_lora(
-    model: Union[Type[object], object],
-) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
+def _supports_lora(model: Union[Type[object], object]) -> bool:
     if isinstance(model, type):
         return isinstance(model, _SupportsLoRAType)
 
@@ -175,10 +172,7 @@ def make_empty_intermediate_tensors(
 
     def forward(
         self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: "AttentionMetadata",
+        *,
         intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
         """
@@ -205,10 +199,7 @@ def make_empty_intermediate_tensors(
 
     def forward(
         self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: "AttentionMetadata",
+        *,
         intermediate_tensors: Optional["IntermediateTensors"],
     ) -> Union[torch.Tensor, "IntermediateTensors"]:
         ...
@@ -257,24 +248,19 @@ def supports_pp(
     return supports_attributes and supports_inspect
 
 
-def _supports_pp_attributes(
-    model: Union[Type[object], object],
-) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
+def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
    if isinstance(model, type):
         return isinstance(model, _SupportsPPType)
 
     return isinstance(model, SupportsPP)
 
 
-def _supports_pp_inspect(
-    model: Union[Type[object], object],
-) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
+def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
     model_forward = getattr(model, "forward", None)
     if not callable(model_forward):
         return False
 
-    forward_params = inspect.signature(model_forward).parameters
-    return "intermediate_tensors" in forward_params
+    return supports_kw(model_forward, "intermediate_tensors")
 
 
 @runtime_checkable
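The hand-rolled inspect.signature check is replaced by the vllm.utils.supports_kw helper. A standalone sketch of what such a helper plausibly covers (an illustration of the idea, not the vLLM source): a keyword counts as supported if it is a named parameter, or if the callable accepts **kwargs.

import inspect
from typing import Callable

def supports_kw_sketch(func: Callable, kw_name: str) -> bool:
    # Accept the keyword if it is a named parameter, or if the callable
    # takes **kwargs and could therefore receive it anyway.
    params = inspect.signature(func).parameters
    if kw_name in params:
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD
               for p in params.values())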
