Skip to content

Commit ed18ba7

Browse files
committed
bert_with_rope
1 parent 5b95dd3 commit ed18ba7

File tree

6 files changed

+330
-348
lines changed

6 files changed

+330
-348
lines changed

tests/models/language/pooling/test_nomic.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from tests.models.utils import EmbedModelInfo
5+
from ...utils import EmbedModelInfo, run_embedding_correctness_test
66

77
MODELS = [
88
EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
@@ -21,11 +21,26 @@
2121

2222

2323
@pytest.mark.parametrize("model_info", MODELS)
24-
def test_models(
25-
hf_runner,
26-
vllm_runner,
27-
model_info: EmbedModelInfo,
28-
monkeypatch,
29-
) -> None:
24+
def test_models(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None:
3025
from .mteb_utils import mteb_test_embed_models
3126
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
27+
28+
29+
@pytest.mark.parametrize("model_info", MODELS)
30+
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
31+
example_prompts) -> None:
32+
if not model_info.enable_test:
33+
pytest.skip("Skipping test.")
34+
35+
with vllm_runner(model_info.name,
36+
task="embed",
37+
dtype=model_info.dtype,
38+
max_model_len=None) as vllm_model:
39+
vllm_outputs = vllm_model.encode(example_prompts)
40+
41+
with hf_runner(
42+
model_info.name,
43+
dtype=model_info.dtype,
44+
is_sentence_transformer=True,
45+
) as hf_model:
46+
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)

tests/models/language/pooling/test_snowflake_arctic_embed.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from tests.models.utils import EmbedModelInfo
5+
from ...utils import EmbedModelInfo, run_embedding_correctness_test
66

77
MODELS = [
88
EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
@@ -41,11 +41,34 @@
4141

4242

4343
@pytest.mark.parametrize("model_info", MODELS)
44-
def test_models(
44+
def test_models_mteb(
4545
hf_runner,
4646
vllm_runner,
4747
model_info: EmbedModelInfo,
48-
monkeypatch,
4948
) -> None:
5049
from .mteb_utils import mteb_test_embed_models
5150
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
51+
52+
53+
@pytest.mark.parametrize("model_info", MODELS)
54+
def test_models_correctness(
55+
hf_runner,
56+
vllm_runner,
57+
model_info: EmbedModelInfo,
58+
example_prompts,
59+
) -> None:
60+
if not model_info.enable_test:
61+
pytest.skip("Skipping test.")
62+
63+
with vllm_runner(model_info.name,
64+
task="embed",
65+
dtype=model_info.dtype,
66+
max_model_len=None) as vllm_model:
67+
vllm_outputs = vllm_model.encode(example_prompts)
68+
69+
with hf_runner(
70+
model_info.name,
71+
dtype=model_info.dtype,
72+
is_sentence_transformer=True,
73+
) as hf_model:
74+
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)

0 commit comments

Comments
 (0)