[CI] Add MTEB testing to verify the accuracy of the embedding model #17175

Open · wants to merge 14 commits into base: main
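For reference, a minimal sketch (assuming `mteb` and `sentence-transformers` are installed) of how the baseline STS12 main score is produced; it mirrors `run_mteb_embed_task_st` added in `mteb_utils.py` below, with `BAAI/bge-m3` as the model under test:

```python
import mteb
from sentence_transformers import SentenceTransformer

# Reference path: score BAAI/bge-m3 on the STS12 task with the
# SentenceTransformer implementation, as run_mteb_embed_task_st does.
model = SentenceTransformer("BAAI/bge-m3")
tasks = mteb.get_tasks(tasks=["STS12"])
evaluation = mteb.MTEB(tasks=tasks)
results = evaluation.run(model, verbosity=0, output_folder=None)

# The new CI test compares this main score against the score obtained
# through vLLM's OpenAI-compatible embeddings endpoint (rel_tol=1e-4).
print(results[0].scores["test"][0]["main_score"])
```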
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -413,6 +413,7 @@ steps:
  - vllm/entrypoints/openai/
  - vllm/model_executor/models/whisper.py
  commands: # LMEval+Transcription WER check
  - pip install mteb
  - pytest -s entrypoints/openai/correctness/

- label: Encoder Decoder tests # 5min
4 changes: 2 additions & 2 deletions requirements/test.txt
@@ -95,7 +95,7 @@ dataproperty==1.0.1
    # via
    #   pytablewriter
    #   tabledata
datasets==3.0.2
datasets==2.21.0
    # via
    #   evaluate
    #   lm-eval
@@ -151,7 +151,7 @@ frozenlist==1.5.0
    #   aiohttp
    #   aiosignal
    #   ray
fsspec==2024.9.0
fsspec==2024.6.1
    # via
    #   datasets
    #   evaluate
42 changes: 42 additions & 0 deletions tests/entrypoints/openai/correctness/test_mteb.py
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
import math
import os

import pytest

from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS,
                                                       OpenAIClientMtebEncoder,
                                                       run_mteb_embed_task,
                                                       run_mteb_embed_task_st)
from tests.utils import RemoteOpenAIServer

os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"

MODEL_NAME = "BAAI/bge-m3"
DTYPE = "float16"
MAIN_SCORE = 0.7873427091972599


@pytest.fixture(scope="module")
def server():
    args = [
        "--task", "embed", "--dtype", DTYPE, "--enforce-eager",
        "--max-model-len", "512"
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


def test_mteb(server):
    client = server.get_client()
    encoder = OpenAIClientMtebEncoder(MODEL_NAME, client)
    vllm_main_score = run_mteb_embed_task(encoder, MTEB_EMBED_TASKS)
    st_main_score = MAIN_SCORE or run_mteb_embed_task_st(
        MODEL_NAME, MTEB_EMBED_TASKS)

    print("VLLM main score: ", vllm_main_score)
    print("SentenceTransformer main score: ", st_main_score)
    print("Difference: ", st_main_score - vllm_main_score)

    assert math.isclose(st_main_score, vllm_main_score, rel_tol=1e-4)
67 changes: 67 additions & 0 deletions tests/models/language/pooling/mteb_utils.py
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence

import mteb
import numpy as np

MTEB_EMBED_TASKS = ["STS12"]


class VllmMtebEncoder(mteb.Encoder):

    def __init__(self, vllm_model):
        super().__init__()
        self.model = vllm_model
        self.rng = np.random.default_rng(seed=42)

    def encode(
        self,
        sentences: Sequence[str],
        *args,
        **kwargs,
    ) -> np.ndarray:
        # Hoping to discover potential scheduling
        # issues by randomizing the order.
        r = self.rng.permutation(len(sentences))
        sentences = [sentences[i] for i in r]
        outputs = self.model.encode(sentences, use_tqdm=False)
        embeds = np.array(outputs)
        embeds = embeds[np.argsort(r)]
        return embeds


class OpenAIClientMtebEncoder(mteb.Encoder):

    def __init__(self, model_name: str, client):
        super().__init__()
        self.model_name = model_name
        self.client = client
        self.rng = np.random.default_rng(seed=42)

    def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray:
        # Hoping to discover potential scheduling
        # issues by randomizing the order.
        r = self.rng.permutation(len(sentences))
        sentences = [sentences[i] for i in r]

        embeddings = self.client.embeddings.create(model=self.model_name,
                                                   input=sentences)
        outputs = [d.embedding for d in embeddings.data]
        embeds = np.array(outputs)
        embeds = embeds[np.argsort(r)]
        return embeds


def run_mteb_embed_task(encoder, tasks):
    tasks = mteb.get_tasks(tasks=tasks)
    evaluation = mteb.MTEB(tasks=tasks)
    results = evaluation.run(encoder, verbosity=0, output_folder=None)

    main_score = results[0].scores["test"][0]["main_score"]
    return main_score


def run_mteb_embed_task_st(model_name, tasks):
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer(model_name)
    return run_mteb_embed_task(model, tasks)
34 changes: 13 additions & 21 deletions tests/models/language/pooling/test_snowflake_arctic_embed.py
@@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
import math

import pytest

from ...utils import EmbedModelInfo, check_embeddings_close
from tests.models.utils import EmbedModelInfo

EMBEDDING_PROMPTS = [
    'what is snowflake?', 'Where can I get the best tacos?', 'The Data Cloud!',
    'Mexico City of Course!'
]
from .mteb_utils import MTEB_EMBED_TASKS, VllmMtebEncoder, run_mteb_embed_task

MODELS = [
    EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
@@ -45,35 +44,29 @@


@pytest.mark.parametrize("model_info", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model_info: EmbedModelInfo,
dtype: str,
monkeypatch,
) -> None:
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.")

example_prompts = example_prompts + EMBEDDING_PROMPTS

vllm_extra_kwargs = {
"hf_overrides": {
"is_matryoshka": model_info.is_matryoshka
}
}

with hf_runner(model_info.name, dtype=dtype,
is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)
with hf_runner(model_info.name, is_sentence_transformer=True) as hf_model:
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)

with vllm_runner(model_info.name,
task="embed",
dtype=dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:

@@ -84,12 +77,11 @@ def test_models(
        assert (model_info.architecture
                in vllm_model.model.llm_engine.model_config.architectures)

        vllm_outputs = vllm_model.encode(example_prompts)
        vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                              MTEB_EMBED_TASKS)

    print("VLLM main score: ", vllm_main_score)
    print("SentenceTransformer main score: ", st_main_score)
    print("Difference: ", st_main_score - vllm_main_score)

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )
    assert math.isclose(st_main_score, vllm_main_score, rel_tol=1e-4)
20 changes: 14 additions & 6 deletions vllm/model_executor/models/bert.py
@@ -641,6 +641,9 @@ def _build_model(self,
        assert not config.mlp_fc2_bias
        assert not config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        config.layer_norm_eps = config.layer_norm_epsilon
        config.position_embedding_type = "rotary"
        config.intermediate_size = config.n_inner
@@ -649,17 +652,21 @@ def _build_model(self,
        config.num_hidden_layers = config.n_layer

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_emb_dim = head_dim * config.rotary_emb_fraction
        rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "rotary_dim": rotary_emb_dim,
            "max_position": config.max_trained_positions,
            "base": config.rotary_emb_base,
            "rope_scaling": {
                "rope_type": "dynamic",
                "factor": config.rotary_scaling_factor
            }
            "base": getattr(config, "rope_theta", config.rotary_emb_base),
            "rope_scaling": getattr(config, "rope_scaling", None)
        }

        # We ignore config.rotary_scaling_factor so that, for datasets shorter
        # than max_trained_positions (2048), the results stay consistent with
        # SentenceTransformer.
        # Context extension instead uses vLLM-style rope_theta and rope_scaling.
        # See #17175.

        return BertModel(vllm_config=vllm_config,
                         prefix=prefix,
                         bias=False,
@@ -695,6 +702,7 @@ def _build_model(self,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
"rope_scaling": getattr(config, "rope_scaling", None)
}

model = BertModel(vllm_config=vllm_config,
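The comment in this hunk notes that context extension for these rotary BERT variants now goes through vLLM-style `rope_theta` and `rope_scaling` rather than `rotary_scaling_factor`. A minimal sketch of how those fields could be supplied via `hf_overrides`; the model name and the override values are illustrative assumptions, not part of this PR:

```python
from vllm import LLM

# Hypothetical sketch: pass vLLM-style rope parameters through hf_overrides so
# the nomic-bert _build_model picks them up via
# getattr(config, "rope_theta", ...) and getattr(config, "rope_scaling", None).
llm = LLM(
    model="Snowflake/snowflake-arctic-embed-m-long",  # assumed nomic-bert checkpoint
    task="embed",
    trust_remote_code=True,
    max_model_len=4096,
    hf_overrides={
        "rope_theta": 10000,
        "rope_scaling": {"rope_type": "dynamic", "factor": 2.0},
    },
)
embeddings = llm.encode(["what is snowflake?"])
```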
2 changes: 1 addition & 1 deletion vllm/model_executor/models/roberta.py
@@ -135,7 +135,7 @@ def _build_model(self,
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rotary_emb_base,
"base": getattr(config, "rope_theta", config.rotary_emb_base),
"rope_scaling": getattr(config, "rope_scaling", None)
}
