Skip to content

Commit

Permalink
Merge pull request #148 from McGill-NLP/mteb-eval-custom
Browse files Browse the repository at this point in the history
mteb eval custom script
  • Loading branch information
vaibhavad authored Oct 8, 2024
2 parents b7b504d + c7f2fc6 commit 250292a
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ LLM2Vec is a simple recipe to convert decoder-only LLMs into text encoders. It c

**************************** **Updates** ****************************

* 03/10: Added support for latest transformer versions, which support Llama 3.1, 3.2 and other latest models
* 03/10: Added support for the latest `transformers` versions, which support Llama 3.1, 3.2 and other recent models. Expanded support to evaluate any LLM2Vec model; see [mteb_eval_custom.py](https://github.com/McGill-NLP/llm2vec/blob/main/experiments/mteb_eval_custom.py).

* 04/07: Added support for Gemma and Qwen-2 models, huge thanks to [@bzantium](https://github.com/bzantium) for the contribution.

Expand Down
98 changes: 98 additions & 0 deletions experiments/mteb_eval_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import argparse
from typing import Any
import mteb
import json
import torch

import numpy as np
from mteb.models.instructions import task_to_instruction
from mteb.models.text_formatting_utils import corpus_to_texts

from llm2vec import LLM2Vec

def llm2vec_instruction(instruction):
    """Normalize an instruction string so that it ends with a colon.

    A non-empty instruction that does not already end in ":" has its
    trailing periods removed and ":" appended. An empty instruction, or one
    that already ends in ":", is returned unchanged.

    Args:
        instruction: The instruction text.

    Returns:
        The normalized instruction text.
    """
    if instruction and not instruction.endswith(":"):
        # Drop trailing periods only: str.strip(".") would also eat leading
        # ones (".NET questions." -> "NET questions:").
        instruction = instruction.rstrip(".") + ":"
    return instruction


class LLM2VecWrapper:
def __init__(self, model=None, task_to_instructions=None):

self.task_to_instructions = task_to_instructions
self.model = model

def encode(
self,
sentences: list[str],
*,
prompt_name: str = None,
**kwargs: Any, # noqa
) -> np.ndarray:
if prompt_name is not None:
instruction = (
self.task_to_instructions[prompt_name]
if self.task_to_instructions
and prompt_name in self.task_to_instructions
else llm2vec_instruction(task_to_instruction(prompt_name))
)
else:
instruction = ""

sentences = [[instruction, sentence] for sentence in sentences]
return self.model.encode(sentences, **kwargs)

def encode_corpus(
self,
corpus: list[dict[str, str]] | dict[str, list[str]] | list[str],
prompt_name: str = None,
**kwargs: Any,
) -> np.ndarray:
sentences = corpus_to_texts(corpus, sep=" ")
sentences = [["", sentence] for sentence in sentences]
if "request_qid" in kwargs:
kwargs.pop("request_qid")
return self.model.encode(sentences, **kwargs)

def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray:
return self.encode(queries, **kwargs)


if __name__ == "__main__":
    # Command-line entry point: load an LLM2Vec model, wrap it, and run a
    # single MTEB task, writing results under --output_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_model_name_or_path",
        type=str,
        default="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
        help="Base model passed to LLM2Vec.from_pretrained.",
    )
    parser.add_argument(
        "--peft_model_name_or_path",
        type=str,
        default="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
        help="PEFT weights passed to LLM2Vec.from_pretrained.",
    )
    parser.add_argument(
        "--task_name",
        type=str,
        default="STS16",
        help="Name of the MTEB task to run.",
    )
    parser.add_argument(
        "--task_to_instructions_fp",
        type=str,
        default="test_configs/mteb/task_to_instructions.json",
        help=(
            "JSON file mapping task names to instructions. "
            "Pass an empty string to skip it and use mteb's built-in instructions."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="results",
        help="Folder handed to MTEB as output_folder.",
    )

    args = parser.parse_args()

    # The default is a path string, so a bare `is not None` test could never
    # be false from the command line. Treat an empty path as "no overrides";
    # the wrapper then falls back to mteb's task_to_instruction lookup.
    task_to_instructions = None
    if args.task_to_instructions_fp:
        # JSON is UTF-8; do not depend on the platform's locale encoding.
        with open(args.task_to_instructions_fp, "r", encoding="utf-8") as f:
            task_to_instructions = json.load(f)

    l2v_model = LLM2Vec.from_pretrained(
        args.base_model_name_or_path,
        peft_model_name_or_path=args.peft_model_name_or_path,
        device_map="cuda" if torch.cuda.is_available() else "cpu",
        torch_dtype=torch.bfloat16,
    )

    model = LLM2VecWrapper(model=l2v_model, task_to_instructions=task_to_instructions)
    tasks = mteb.get_tasks(tasks=[args.task_name])
    evaluation = mteb.MTEB(tasks=tasks)
    results = evaluation.run(model, output_folder=args.output_dir)

0 comments on commit 250292a

Please sign in to comment.