Commit e40f8a0

xyang16 authored and garg-amit committed
[Model] Support Gemma2 embedding model (vllm-project#9004)
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent 8b39bbc commit e40f8a0

File tree

5 files changed: +99 -3 lines changed

tests/conftest.py
tests/models/embedding/language/test_embedding.py
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2_embedding.py
vllm/model_executor/models/registry.py

tests/conftest.py

Lines changed: 1 addition & 0 deletions

@@ -277,6 +277,7 @@ def __init__(
                 SentenceTransformer(
                     model_name,
                     device="cpu",
+                    trust_remote_code=True,
                 ).to(dtype=torch_dtype))
         else:
             model_kwargs = model_kwargs if model_kwargs is not None else {}
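The new flag lets the HF test runner load Hub repositories that ship custom modeling code. A minimal sketch of what it enables, assuming the newly added test model is such a repository (not part of the diff):

from sentence_transformers import SentenceTransformer

# Without trust_remote_code=True, sentence-transformers refuses to execute
# custom code bundled with a model repository and raises an error instead.
model = SentenceTransformer(
    "BAAI/bge-multilingual-gemma2",
    device="cpu",
    trust_remote_code=True,
)
embeddings = model.encode(["a quick smoke test"])  # numpy array of vectors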

tests/models/embedding/language/test_embedding.py

Lines changed: 10 additions & 1 deletion

@@ -1,13 +1,14 @@
 """Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
 
-Run `pytest tests/models/test_llama_embedding.py`.
+Run `pytest tests/models/embedding/language/test_embedding.py`.
 """
 import pytest
 import torch
 import torch.nn.functional as F
 
 MODELS = [
     "intfloat/e5-mistral-7b-instruct",
+    "BAAI/bge-multilingual-gemma2",
 ]
 
 
@@ -28,6 +29,14 @@ def test_models(
     model: str,
     dtype: str,
 ) -> None:
+    # The example_prompts end with "\n", for example:
+    # "Write a short story about a robot that dreams for the first time.\n"
+    # sentence_transformers strips the input texts, see:
+    # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
+    # This makes the input_ids differ between hf_model and vllm_model,
+    # so we strip the input texts to avoid the test failing.
+    example_prompts = [str(s).strip() for s in example_prompts]
+
     with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model:
         hf_outputs = hf_model.encode(example_prompts)
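The mismatch the comment describes is visible directly at the tokenizer level; a hedged illustration (tokenizer repo assumed to match the test model):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("BAAI/bge-multilingual-gemma2")
# sentence-transformers encodes the stripped text while vLLM encodes the raw
# prompt, so the trailing "\n" produces a different input_ids sequence.
print(tok("Write a short story.").input_ids)
print(tok("Write a short story.\n").input_ids)  # differs by the newline token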

vllm/model_executor/models/gemma2.py

Lines changed: 5 additions & 2 deletions

@@ -278,11 +278,14 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            hidden_states = self.embed_tokens(input_ids)
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.embed_tokens(input_ids)
             hidden_states *= self.normalizer
-
             residual = None
         else:
             assert intermediate_tensors is not None
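The new parameter lets callers hand in precomputed embeddings instead of token ids; the embedding wrapper in the next file passes it through. A minimal sketch of the dispatch, assuming (as in vLLM's Gemma2) that the normalizer is the sqrt(hidden_size) scale applied before the first decoder layer:

from typing import Optional

import torch
from torch import nn

def first_rank_hidden_states(
    input_ids: torch.Tensor,
    embed_tokens: nn.Embedding,
    normalizer: float,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Prefer caller-supplied embeddings; fall back to the token lookup.
    if inputs_embeds is not None:
        hidden_states = inputs_embeds
    else:
        hidden_states = embed_tokens(input_ids)
    # Gemma2 scales embeddings before the first decoder layer.
    return hidden_states * normalizer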
vllm/model_executor/models/gemma2_embedding.py (new file)

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn

from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput


class Gemma2EmbeddingModel(nn.Module):
    """A model that uses Gemma2 with additional embedding functionalities.

    This class encapsulates the Gemma2Model and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of Gemma2Model used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()
        self.model = Gemma2Model(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model.forward(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors,
                                  inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.model.named_parameters())
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
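With this wrapper registered (see the registry change below), the model runs through vLLM's embedding path: the Pooler takes the last token's hidden state and L2-normalizes it. A hedged usage sketch, not part of the diff, using the `LLM.encode` API as it existed around this commit:

from vllm import LLM

# The architecture "Gemma2Model" now resolves to Gemma2EmbeddingModel.
llm = LLM(model="BAAI/bge-multilingual-gemma2")
outputs = llm.encode(["What is the capital of France?"])
for out in outputs:
    print(len(out.outputs.embedding))  # embedding dimensionality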

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions

@@ -83,6 +83,7 @@
 _EMBEDDING_MODELS = {
     "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
+    "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"),
 }
 
 _MULTIMODAL_MODELS = {
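For context, each entry maps a Hugging Face `architectures` name to a (module, class) pair that vLLM imports lazily. A simplified sketch of the lookup (vLLM's real resolver adds more fallbacks and error handling):

import importlib

_EMBEDDING_MODELS = {
    "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"),
}

def resolve_embedding_model(arch: str):
    # Import the module under vllm.model_executor.models and fetch the class.
    module_name, cls_name = _EMBEDDING_MODELS[arch]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)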
