Skip to content

Commit e4b8713

Browse files
authored
[New Model]: nomic-embed-text-v2-moe (#17785)
1 parent 06c0922 commit e4b8713

File tree

9 files changed

+899
-364
lines changed

9 files changed

+899
-364
lines changed

docs/source/models/supported_models.md

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ Specified using `--task embed`.
622622
* [PP](#distributed-serving)
623623
- * `BertModel`
624624
* BERT-based
625-
* `BAAI/bge-base-en-v1.5`, etc.
625+
* `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc.
626626
*
627627
*
628628
- * `Gemma2Model`
@@ -635,6 +635,16 @@ Specified using `--task embed`.
635635
* `parasail-ai/GritLM-7B-vllm`.
636636
* ✅︎
637637
* ✅︎
638+
- * `GteModel`
639+
* GteModel
640+
* `Snowflake/snowflake-arctic-embed-m-v2.0`.
641+
*
642+
*
643+
- * `NomicBertModel`
644+
* NomicBertModel
645+
* `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc.
646+
*
647+
*
638648
- * `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc.
639649
* Llama-based
640650
* `intfloat/e5-mistral-7b-instruct`, etc.
@@ -647,12 +657,12 @@ Specified using `--task embed`.
647657
* ✅︎
648658
- * `RobertaModel`, `RobertaForMaskedLM`
649659
* RoBERTa-based
650-
* `sentence-transformers/all-roberta-large-v1`, `sentence-transformers/all-roberta-large-v1`, etc.
660+
* `sentence-transformers/all-roberta-large-v1`, etc.
651661
*
652662
*
653663
- * `XLMRobertaModel`
654664
* XLM-RoBERTa-based
655-
* `intfloat/multilingual-e5-large`, `jinaai/jina-reranker-v2-base-multilingual`, etc.
665+
* `intfloat/multilingual-e5-large`, `jinaai/jina-reranker-v2-base-multilingual`, `Snowflake/snowflake-arctic-embed-l-v2.0`, `jinaai/jina-embeddings-v3`(see note), etc.
656666
*
657667
*
658668
:::
@@ -670,6 +680,10 @@ For both the 1.5B and 7B variants, you also need to enable `--trust-remote-code`
670680
See [relevant issue on HF Transformers](https://github.com/huggingface/transformers/issues/34882).
671681
:::
672682

683+
:::{note}
684+
`jinaai/jina-embeddings-v3` supports multiple tasks through lora, while vllm temporarily only supports text-matching tasks by merging lora weights.
685+
:::
686+
673687
If your model is not in the above list, we will try to automatically convert the model using
674688
{func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
675689
of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import math
3+
from collections.abc import Sequence
4+
5+
import mteb
6+
import numpy as np
7+
import pytest
8+
9+
from tests.models.utils import EmbedModelInfo
10+
11+
# Most models on the STS12 task (See #17175):
12+
# - Model implementation and minor changes in tensor dtype
13+
# results in differences less than 1e-4
14+
# - Different model results in differences more than 1e-3
15+
# 1e-4 is a good tolerance threshold
16+
MTEB_EMBED_TASKS = ["STS12"]
17+
MTEB_EMBED_TOL = 1e-4
18+
19+
20+
class VllmMtebEncoder(mteb.Encoder):
21+
22+
def __init__(self, vllm_model):
23+
super().__init__()
24+
self.model = vllm_model
25+
self.rng = np.random.default_rng(seed=42)
26+
27+
def encode(
28+
self,
29+
sentences: Sequence[str],
30+
*args,
31+
**kwargs,
32+
) -> np.ndarray:
33+
# Hoping to discover potential scheduling
34+
# issues by randomizing the order.
35+
r = self.rng.permutation(len(sentences))
36+
sentences = [sentences[i] for i in r]
37+
outputs = self.model.encode(sentences, use_tqdm=False)
38+
embeds = np.array(outputs)
39+
embeds = embeds[np.argsort(r)]
40+
return embeds
41+
42+
43+
class OpenAIClientMtebEncoder(mteb.Encoder):
44+
45+
def __init__(self, model_name: str, client):
46+
super().__init__()
47+
self.model_name = model_name
48+
self.client = client
49+
self.rng = np.random.default_rng(seed=42)
50+
51+
def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray:
52+
# Hoping to discover potential scheduling
53+
# issues by randomizing the order.
54+
r = self.rng.permutation(len(sentences))
55+
sentences = [sentences[i] for i in r]
56+
57+
embeddings = self.client.embeddings.create(model=self.model_name,
58+
input=sentences)
59+
outputs = [d.embedding for d in embeddings.data]
60+
embeds = np.array(outputs)
61+
embeds = embeds[np.argsort(r)]
62+
return embeds
63+
64+
65+
def run_mteb_embed_task(encoder, tasks):
66+
tasks = mteb.get_tasks(tasks=tasks)
67+
evaluation = mteb.MTEB(tasks=tasks)
68+
results = evaluation.run(encoder, verbosity=0, output_folder=None)
69+
70+
main_score = results[0].scores["test"][0]["main_score"]
71+
return main_score
72+
73+
74+
def run_mteb_embed_task_st(model_name, tasks):
75+
from sentence_transformers import SentenceTransformer
76+
model = SentenceTransformer(model_name)
77+
return run_mteb_embed_task(model, tasks)
78+
79+
80+
def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
81+
if not model_info.enable_test:
82+
# A model family has many models with the same architecture,
83+
# and we don't need to test each one.
84+
pytest.skip("Skipping test.")
85+
86+
with vllm_runner(model_info.name,
87+
task="embed",
88+
max_model_len=None,
89+
dtype=model_info.dtype) as vllm_model:
90+
91+
if model_info.architecture:
92+
assert (model_info.architecture
93+
in vllm_model.model.llm_engine.model_config.architectures)
94+
95+
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
96+
MTEB_EMBED_TASKS)
97+
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
98+
model_dtype = getattr(
99+
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
100+
vllm_dtype)
101+
102+
with hf_runner(model_info.name,
103+
is_sentence_transformer=True,
104+
dtype=model_dtype) as hf_model:
105+
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
106+
107+
print("VLLM:", vllm_dtype, vllm_main_score)
108+
print("SentenceTransformer:", model_dtype, st_main_score)
109+
print("Difference:", st_main_score - vllm_main_score)
110+
111+
assert math.isclose(st_main_score, vllm_main_score, rel_tol=MTEB_EMBED_TOL)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import pytest
4+
5+
from ...utils import EmbedModelInfo, run_embedding_correctness_test
6+
7+
MODELS = [
8+
EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
9+
architecture="NomicBertModel",
10+
dtype="float32",
11+
enable_test=True),
12+
EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5",
13+
architecture="NomicBertModel",
14+
dtype="float32",
15+
enable_test=False),
16+
EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
17+
architecture="NomicBertModel",
18+
dtype="float32",
19+
enable_test=True)
20+
]
21+
22+
23+
@pytest.mark.parametrize("model_info", MODELS)
24+
def test_models_mteb(hf_runner, vllm_runner,
25+
model_info: EmbedModelInfo) -> None:
26+
from .mteb_utils import mteb_test_embed_models
27+
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
28+
29+
30+
@pytest.mark.parametrize("model_info", MODELS)
31+
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
32+
example_prompts) -> None:
33+
if not model_info.enable_test:
34+
pytest.skip("Skipping test.")
35+
36+
with vllm_runner(model_info.name,
37+
task="embed",
38+
dtype=model_info.dtype,
39+
max_model_len=None) as vllm_model:
40+
vllm_outputs = vllm_model.encode(example_prompts)
41+
42+
with hf_runner(
43+
model_info.name,
44+
dtype=model_info.dtype,
45+
is_sentence_transformer=True,
46+
) as hf_model:
47+
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
Lines changed: 22 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
# SPDX-License-Identifier: Apache-2.0
2-
import pytest
32

4-
from ...utils import EmbedModelInfo, check_embeddings_close
3+
import pytest
54

6-
EMBEDDING_PROMPTS = [
7-
'what is snowflake?', 'Where can I get the best tacos?', 'The Data Cloud!',
8-
'Mexico City of Course!'
9-
]
5+
from ...utils import EmbedModelInfo, run_embedding_correctness_test
106

117
MODELS = [
128
EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
@@ -45,51 +41,34 @@
4541

4642

4743
@pytest.mark.parametrize("model_info", MODELS)
48-
@pytest.mark.parametrize("dtype", ["half"])
49-
def test_models(
44+
def test_models_mteb(
5045
hf_runner,
5146
vllm_runner,
52-
example_prompts,
5347
model_info: EmbedModelInfo,
54-
dtype: str,
55-
monkeypatch,
5648
) -> None:
57-
if not model_info.enable_test:
58-
# A model family has many models with the same architecture,
59-
# and we don't need to test each one.
60-
pytest.skip("Skipping test.")
61-
62-
example_prompts = example_prompts + EMBEDDING_PROMPTS
49+
from .mteb_utils import mteb_test_embed_models
50+
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
6351

64-
vllm_extra_kwargs = {
65-
"hf_overrides": {
66-
"is_matryoshka": model_info.is_matryoshka
67-
}
68-
}
6952

70-
with hf_runner(model_info.name, dtype=dtype,
71-
is_sentence_transformer=True) as hf_model:
72-
hf_outputs = hf_model.encode(example_prompts)
53+
@pytest.mark.parametrize("model_info", MODELS)
54+
def test_models_correctness(
55+
hf_runner,
56+
vllm_runner,
57+
model_info: EmbedModelInfo,
58+
example_prompts,
59+
) -> None:
60+
if not model_info.enable_test:
61+
pytest.skip("Skipping test.")
7362

7463
with vllm_runner(model_info.name,
7564
task="embed",
76-
dtype=dtype,
77-
max_model_len=None,
78-
**vllm_extra_kwargs) as vllm_model:
79-
80-
assert (vllm_model.model.llm_engine.model_config.is_matryoshka ==
81-
model_info.is_matryoshka)
82-
83-
if model_info.architecture:
84-
assert (model_info.architecture
85-
in vllm_model.model.llm_engine.model_config.architectures)
86-
65+
dtype=model_info.dtype,
66+
max_model_len=None) as vllm_model:
8767
vllm_outputs = vllm_model.encode(example_prompts)
8868

89-
check_embeddings_close(
90-
embeddings_0_lst=hf_outputs,
91-
embeddings_1_lst=vllm_outputs,
92-
name_0="hf",
93-
name_1="vllm",
94-
tol=1e-2,
95-
)
69+
with hf_runner(
70+
model_info.name,
71+
dtype=model_info.dtype,
72+
is_sentence_transformer=True,
73+
) as hf_model:
74+
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)

tests/models/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,10 @@ def matryoshka_fy(tensor: torch.Tensor, dimensions: int):
332332

333333
class EmbedModelInfo(NamedTuple):
334334
name: str
335-
is_matryoshka: bool
335+
is_matryoshka: bool = False
336336
matryoshka_dimensions: Optional[list[int]] = None
337337
architecture: str = ""
338+
dtype: str = "auto"
338339
enable_test: bool = True
339340

340341

0 commit comments

Comments
 (0)