Commit 30c03a6

add framework for adding new model evaluators
1 parent 15ff884 commit 30c03a6

7 files changed: +195 -45 lines changed

prompttools/utils/autoeval.py

Lines changed: 10 additions & 13 deletions

@@ -5,12 +5,15 @@
 # LICENSE file in the root directory of this source tree.


-import os
 from typing import Dict
+<<<<<<< HEAD
 import openai
 import pandas.core.series
+=======
+>>>>>>> 862ca6e (add framework for adding new model evaluators)
 import jinja2
-from .error import PromptToolsUtilityError
+
+from .model_evaluators.EvaluatorUtils import get_evaluator_for_model

 EVALUATION_SYSTEM_PROMPT = """
 Determine whether or not the response is following directions.

@@ -21,18 +24,14 @@
 EVALUATION_USER_TEMPLATE = """
 PROMPT: {{prompt}}
 RESPONSE: {{response}}
-ANSWER:
 """


-def _get_messages(prompt: str, response: str):
+def _get_user_prompt(prompt: str, response: str):
     environment = jinja2.Environment()
     template = environment.from_string(EVALUATION_USER_TEMPLATE)
-    user_message = template.render({"prompt": prompt, "response": response})
-    return [
-        {"role": "system", "content": EVALUATION_SYSTEM_PROMPT},
-        {"role": "user", "content": user_message},
-    ]
+    user_prompt = template.render({"prompt": prompt, "response": response})
+    return user_prompt


 def compute(prompt: str, response: str, model: str = "gpt-4") -> float:

@@ -46,10 +45,8 @@ def compute(prompt: str, response: str, model: str = "gpt-4") -> float:
         model (str): The OpenAI chat model to use for generating an expected response.
             Defaults to GPT-4.
     """
-    if not os.environ["OPENAI_API_KEY"]:
-        raise PromptToolsUtilityError
-    evaluation = openai.ChatCompletion.create(model=model, messages=_get_messages(prompt, response))
-    return 1.0 if "RIGHT" in evaluation["choices"][0]["message"]["content"] else 0.0
+    response = get_evaluator_for_model(model).evaluate(model, EVALUATION_SYSTEM_PROMPT, _get_user_prompt(prompt, response))
+    return 1.0 if "RIGHT" in response else 0.0


 def evaluate(prompt: str, response: str, _metadata: Dict) -> float:
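With this change, autoeval.compute no longer builds the OpenAI chat payload itself; it looks up an evaluator for the requested model and delegates to it. A minimal usage sketch of the intended flow (hypothetical values, assuming OPENAI_API_KEY is set and the leftover merge-conflict markers in the import block are resolved):

# Hypothetical usage of the refactored autoeval.compute (not part of this commit).
from prompttools.utils import autoeval

# get_evaluator_for_model("gpt-4") returns the GptEvaluator, which renders the
# user template, calls the chat API, and hands back the grader's text; compute
# then maps any reply containing "RIGHT" to 1.0 and everything else to 0.0.
score = autoeval.compute(
    prompt="Answer in one word: what color is a ripe banana?",
    response="Yellow",
    model="gpt-4",
)
print(score)  # 1.0 or 0.0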

prompttools/utils/autoeval_from_expected.py

Lines changed: 10 additions & 11 deletions

@@ -5,11 +5,13 @@
 # LICENSE file in the root directory of this source tree.


-import os
-import openai
 import jinja2
+<<<<<<< HEAD
 import pandas
 from .error import PromptToolsUtilityError
+=======
+from .model_evaluators.EvaluatorUtils import get_evaluator_for_model
+>>>>>>> 862ca6e (add framework for adding new model evaluators)

 EVALUATION_SYSTEM_PROMPT = """
 You are a grader evaluating responses to math questions.

@@ -25,14 +27,11 @@
 """


-def _get_messages(prompt: str, expected: str, response: str):
+def _get_user_prompt(prompt: str, expected: str, response: str):
     environment = jinja2.Environment()
     template = environment.from_string(EVALUATION_USER_TEMPLATE)
-    user_message = template.render({"prompt": prompt, "expected": expected, "actual": response})
-    return [
-        {"role": "system", "content": EVALUATION_SYSTEM_PROMPT},
-        {"role": "user", "content": user_message},
-    ]
+    user_prompt = template.render({"prompt": prompt, "expected": expected, "actual": response})
+    return user_prompt


 # TODO: Should this be removed since no one is using it?

@@ -47,9 +46,9 @@ def compute(prompt: str, expected: str, response: str, model: str = "gpt-4") -> float:
         model (str): The OpenAI chat model to use for generating an expected response.
             Defaults to GPT-4.
     """
-    if not os.environ["OPENAI_API_KEY"]:
-        raise PromptToolsUtilityError("Missing API key for evaluation.")
-    evaluation = openai.ChatCompletion.create(model=model, messages=_get_messages(prompt, expected, response))
+    evaluation = get_evaluator_for_model(model).evaluate(
+        model, EVALUATION_SYSTEM_PROMPT, _get_user_prompt(prompt, expected, response)
+    )
     return 1.0 if "RIGHT" in evaluation["choices"][0]["message"]["content"] else 0.0

prompttools/utils/autoeval_scoring.py

Lines changed: 22 additions & 21 deletions

@@ -5,38 +5,38 @@
 # LICENSE file in the root directory of this source tree.


+<<<<<<< HEAD
 import os
 import pandas.core.series
+=======
+>>>>>>> 862ca6e (add framework for adding new model evaluators)
 import jinja2

+from prompttools.utils.model_evaluators.EvaluatorUtils import get_evaluator_for_model
+
 try:
     import anthropic
 except ImportError:
     anthropic = None

-
-AUTO_EVAL_PROMPT_TEMPLATE = """
-{{HUMAN_PROMPT}} Given the fact {{fact}}
-
-Evaluate the following Answer on a scale from 1 - 7. Please only respond with an integer from 1 - 7 with no other text.
+EVALUATION_SYSTEM_PROMPT = """
+Given the Fact and Statement, Evaluate the statement on a scale from 1 - 7.
+Please only respond with an integer from 1 - 7 with no other text.
 Lower score means the answer is factually wrong, higher score means the answer is correct. A medium score for
-uncertain but not wrong.
-
-Answer: {{model_answer}}
+uncertain but not wrong"""

-{{AI_PROMPT}}
-"""
+USER_PROMPT = """
+Fact: {{fact}}
+Statement: {{model_answer}}"""


-def _generate_auto_eval_prompt(fact: str, model_answer: str):
+def _generate_user_prompt(fact: str, model_answer: str):
     environment = jinja2.Environment()
-    template = environment.from_string(AUTO_EVAL_PROMPT_TEMPLATE)
+    template = environment.from_string(USER_PROMPT)
     auto_eval_prompt = template.render(
         {
-            "HUMAN_PROMPT": anthropic.HUMAN_PROMPT,
-            "AI_PROMPT": anthropic.AI_PROMPT,
             "fact": fact,
-            "model_answer": model_answer,
+            "statement": model_answer,
         }
     )
     return auto_eval_prompt

@@ -54,13 +54,10 @@ def compute(fact: str, model_answer: str, model: str = "claude-2") -> float:
         model (str): The model that will be judging how close is the response from the truth.
             Defaults to Claude 2.
     """
-    if not os.environ["ANTHROPIC_API_KEY"]:
-        raise RuntimeError("Missing API key for evaluation.")
-    client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
-    completion_response = client.completions.create(
-        max_tokens_to_sample=100, model=model, prompt=_generate_auto_eval_prompt(fact, model_answer)
+    response = get_evaluator_for_model(model).evaluate(
+        model, EVALUATION_SYSTEM_PROMPT, _generate_user_prompt(fact, model_answer)
     )
-    return int(completion_response.completion)
+    return int(response)


 def autoeval_scoring(row: pandas.core.series.Series, expected: str, response_column_name: str = "response") -> float:

@@ -73,9 +70,13 @@ def autoeval_scoring(row: pandas.core.series.Series, expected: str, response_column_name: str = "response") -> float:
         expected (str): the expected response
         response_column_name (str): name of the column that contains the model's response, defaults to ``"response"``
     """
+<<<<<<< HEAD
     if anthropic is None:
         raise ModuleNotFoundError(
             "Package `anthropic` is required to be installed to use this experiment."
             "Please use `pip install anthropic` to install the package"
         )
     return compute(fact=expected, model_answer=row[response_column_name])
+=======
+    return compute(fact=expected, model_answer=response)
+>>>>>>> 862ca6e (add framework for adding new model evaluators)
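The scoring helper now goes through the same registry, so the Anthropic-specific prompt wrapping lives in AnthropicEvaluator rather than in this module. A rough usage sketch (hypothetical fact and answer, assuming ANTHROPIC_API_KEY is set and the conflict markers left in this file are resolved):

# Hypothetical usage of the refactored scoring path (not part of this commit).
from prompttools.utils import autoeval_scoring

# "claude-2" is routed to AnthropicEvaluator, which wraps the system and user
# prompts in HUMAN_PROMPT/AI_PROMPT; the integer completion becomes the score.
score = autoeval_scoring.compute(
    fact="The Eiffel Tower is in Paris.",
    model_answer="It is located in Paris, France.",
    model="claude-2",
)
print(score)  # an integer from 1 to 7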
prompttools/utils/model_evaluators/AnthropicEvaluator.py

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
# Copyright (c) Hegel AI, Inc.
# All rights reserved.
#
# This source code's license can be found in the
# LICENSE file in the root directory of this source tree.


from overrides import override
from .ModelEvaluator import ModelEvaluator
import jinja2
import anthropic
import os

ANTHROPIC_API_AUTOEVAL_TEMPLATE = """
{{HUMAN_PROMPT}} {{EVALUATION_SYSTEM_PROMPT}}
{{USER_MESSAGE}} {{AI_PROMPT}}
"""


class AnthropicEvaluator(ModelEvaluator):
    def __init__(self) -> None:
        self.client = None
        self.supported_models = ["claude-1", "claude-2"]

    def supports_model(self, model: str):
        return model in self.supported_models

    @override
    def evaluate(self, model: str, evaluation_system_prompt: str, user_message: str):
        if anthropic is None:
            raise ModuleNotFoundError(
                "Package `anthropic` is required to be installed to use this experiment."
                " Please use `pip install anthropic` to install the package"
            )

        if not os.environ["ANTHROPIC_API_KEY"]:
            raise RuntimeError("Missing API key for evaluation.")

        if not self.client:
            self.client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

        environment = jinja2.Environment()
        template = environment.from_string(ANTHROPIC_API_AUTOEVAL_TEMPLATE)
        eval_prompt = template.render(
            {
                "HUMAN_PROMPT": anthropic.HUMAN_PROMPT,
                "EVALUATION_SYSTEM_PROMPT": evaluation_system_prompt,
                "USER_MESSAGE": user_message,
                "AI_PROMPT": anthropic.AI_PROMPT,
            }
        )

        response = self.client.completions.create(max_tokens_to_sample=100, model=model, prompt=eval_prompt)

        return response.completion
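For reference, ANTHROPIC_API_AUTOEVAL_TEMPLATE collapses the system prompt and user message into a single completion-style prompt. A rough illustration of the rendered string (toy values, whitespace approximate):

# Illustration only: what evaluate() sends to the completions API (toy values).
import anthropic

system = "Respond with RIGHT if the statement is correct, WRONG otherwise."
user = "Fact: 2 + 2 = 4\nStatement: two plus two equals four"
# HUMAN_PROMPT and AI_PROMPT are the SDK's turn delimiters (roughly "\n\nHuman:"
# and "\n\nAssistant:"), so the model sees one Human turn holding both prompts
# followed by an empty Assistant turn to complete.
prompt = f"\n{anthropic.HUMAN_PROMPT} {system}\n{user} {anthropic.AI_PROMPT}\n"
print(prompt)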
prompttools/utils/model_evaluators/EvaluatorUtils.py

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
# Copyright (c) Hegel AI, Inc.
# All rights reserved.
#
# This source code's license can be found in the
# LICENSE file in the root directory of this source tree.

from .ModelEvaluator import ModelEvaluator
from .GptEvaluator import GptEvaluator
from .AnthropicEvaluator import AnthropicEvaluator

Evaluators = [GptEvaluator(), AnthropicEvaluator()]


def get_evaluator_for_model(model: str) -> ModelEvaluator:
    for evaluator in Evaluators:
        if evaluator.supports_model(model):
            return evaluator
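get_evaluator_for_model is a first-match lookup over the Evaluators list, and it implicitly returns None for a model no evaluator claims, so callers such as autoeval.compute would fail with an AttributeError for unknown model names. A small dispatch sketch, assuming this module layout and that the anthropic and overrides packages are installed:

# Hypothetical dispatch check (not part of this commit).
from prompttools.utils.model_evaluators.EvaluatorUtils import get_evaluator_for_model

print(type(get_evaluator_for_model("gpt-4")).__name__)     # GptEvaluator
print(type(get_evaluator_for_model("claude-2")).__name__)  # AnthropicEvaluator
print(get_evaluator_for_model("unknown-model"))            # None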
prompttools/utils/model_evaluators/GptEvaluator.py

Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
# Copyright (c) Hegel AI, Inc.
# All rights reserved.
#
# This source code's license can be found in the
# LICENSE file in the root directory of this source tree.

import os
import jinja2
from overrides import override

from prompttools.utils.error import PromptToolsUtilityError
from .ModelEvaluator import ModelEvaluator
import openai

OPENAI_EVAL_PROMPT = """
{{USER_MESSAGE}}
ANSWER:
"""


class GptEvaluator(ModelEvaluator):
    def __init__(self) -> None:
        # source: https://platform.openai.com/docs/models/model-endpoint-compatibility
        self.supported_models = [
            "gpt-4",
            "gpt-4-0613",
            "gpt-4-32k",
            "gpt-4-32k-0613",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-0613",
            "gpt-3.5-turbo-16k",
            "gpt-3.5-turbo-16k-0613",
        ]

    @override
    def supports_model(self, model) -> bool:
        return model in self.supported_models

    @override
    def evaluate(self, model: str, evaluation_system_prompt: str, user_message: str):
        if not os.environ["OPENAI_API_KEY"]:
            raise PromptToolsUtilityError

        response = openai.ChatCompletion.create(
            model=model, messages=self.get_messages(evaluation_system_prompt, user_message)
        )
        return response["choices"][0]["message"]["content"]

    def get_messages(self, evaluation_system_prompt, user_message) -> list:
        environment = jinja2.Environment()
        template = environment.from_string(OPENAI_EVAL_PROMPT)
        eval_prompt = template.render(
            {
                "USER_MESSAGE": user_message,
            }
        )

        messages = [
            {"role": "system", "content": evaluation_system_prompt},
            {"role": "user", "content": eval_prompt},
        ]

        return messages
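The chat payload keeps the evaluation system prompt in its own message and appends "ANSWER:" to the rendered user message. For toy inputs, get_messages produces a structure like the following (illustrative values, whitespace approximate):

# Illustration only: the shape of the payload GptEvaluator sends (toy values).
messages = [
    {"role": "system", "content": "Determine whether or not the response is following directions."},
    {"role": "user", "content": "\nPROMPT: say hi\nRESPONSE: hi\nANSWER:\n"},
]
# openai.ChatCompletion.create(model="gpt-4", messages=messages) then returns
# the grader's verdict in response["choices"][0]["message"]["content"].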
prompttools/utils/model_evaluators/ModelEvaluator.py

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
# Copyright (c) Hegel AI, Inc.
# All rights reserved.
#
# This source code's license can be found in the
# LICENSE file in the root directory of this source tree.


from abc import ABC, abstractmethod


class ModelEvaluator(ABC):
    @abstractmethod
    def evaluate(self, model: str, evaluation_system_prompt: str, user_message: str):
        pass

    @abstractmethod
    def supports_model(self, model: str):
        pass
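This abstract base class is the extension point the commit title refers to: supporting another model family means subclassing it and registering an instance in EvaluatorUtils.Evaluators. A minimal sketch with an invented evaluator name (hypothetical, assuming the module layout above):

# Hypothetical example of adding a new evaluator (not part of this commit).
from prompttools.utils.model_evaluators.ModelEvaluator import ModelEvaluator


class EchoEvaluator(ModelEvaluator):
    """Toy evaluator that approves everything; a real one would call a model API."""

    def supports_model(self, model: str):
        return model == "echo-1"

    def evaluate(self, model: str, evaluation_system_prompt: str, user_message: str):
        return "RIGHT"


# Appending an instance to EvaluatorUtils.Evaluators would make
# autoeval.compute(..., model="echo-1") work without further changes.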
