
Commit cdad61b

DH-5725/adding the log prob evaluator
1 parent cf1a2e7 commit cdad61b

9 files changed: +164 -1 lines changed

dataherald/api/types/requests.py (+1)

@@ -14,6 +14,7 @@ class SQLGenerationRequest(BaseModel):
     low_latency_mode: bool = False
     llm_config: LLMConfig | None
     evaluate: bool = False
+    evaluation_quantile: int = 25
     sql: str | None
     metadata: dict | None
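Callers can now pick the quantile used for scoring. A minimal sketch of a request that opts into evaluation at a non-default quantile, assuming the request fields omitted from this hunk are optional (the metadata value is hypothetical):

from dataherald.api.types.requests import SQLGenerationRequest

request = SQLGenerationRequest(
    evaluate=True,            # run the evaluator on the generated SQL
    evaluation_quantile=40,   # score at the 40th percentile instead of the default 25th
    metadata={"source": "docs-example"},  # hypothetical metadata
)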

dataherald/config.py (+1 -1)

@@ -28,7 +28,7 @@ class Settings(BaseSettings):
     )

     eval_impl: str = os.environ.get(
-        "EVALUATOR", "dataherald.eval.simple_evaluator.SimpleEvaluator"
+        "EVALUATOR", "dataherald.eval.logprob_evaluator.LogProbEvaluator"
     )
     db_impl: str = os.environ.get("DB", "dataherald.db.mongo.MongoDB")
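The default evaluator is now the LogProbEvaluator; the EVALUATOR environment variable still overrides it. A minimal sketch of reverting to the previous default, assuming it runs before Settings is constructed:

import os

# Revert to the pre-commit default; the dotted path comes from the removed line above.
os.environ["EVALUATOR"] = "dataherald.eval.simple_evaluator.SimpleEvaluator"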

dataherald/eval/logprob_evaluator.py (+74)

@@ -0,0 +1,74 @@
+import logging
+
+from overrides import override
+
+from dataherald.config import System
+from dataherald.eval import Evaluation, Evaluator
+from dataherald.sql_database.models.types import DatabaseConnection
+from dataherald.types import Prompt, SQLGeneration
+
+logger = logging.getLogger(__name__)
+MAX_QUANTILE = 100
+
+
+class LogProbEvaluator(Evaluator):
+    def __init__(self, system: System):
+        super().__init__(system)
+        self.system = system
+
+    def extract_query_probs(self, tokens, probs):
+        """Extract the probabilities for each token in the query."""
+        query_probs = []
+        query_found = False
+        for token, prob in zip(tokens, probs, strict=False):
+            if "```" in token or "`" in token:
+                query_found = True
+            if query_found:
+                query_probs.append((token, prob))
+        return query_probs
+
+    @override
+    def evaluate(
+        self,
+        user_prompt: Prompt,
+        sql_generation: SQLGeneration,
+        database_connection: DatabaseConnection,  # noqa: ARG002
+    ) -> Evaluation:
+        logger.info(
+            f"(LogProb evaluator) Generating score for the question/sql pair: {str(user_prompt.text)}/ {str(sql_generation.sql)}"
+        )
+        if sql_generation.status == "INVALID":
+            logger.info(
+                f"(LogProb evaluator) SQL query: {sql_generation.sql} is not valid. Returning score 0"
+            )
+            return Evaluation(
+                question_id=user_prompt.id, answer_id=sql_generation.id, score=0.0
+            )
+        for i in range(len(sql_generation.tokens) - 1, -1, -1):
+            query_probs = self.extract_query_probs(
+                sql_generation.tokens[i], sql_generation.probs[i]
+            )
+            if query_probs:
+                break
+        if not query_probs:
+            return Evaluation(
+                question_id=user_prompt.id, answer_id=sql_generation.id, score=0.0
+            )
+        probabilities = sorted([prob for token, prob in query_probs])
+        tokens = [token for token, prob in query_probs]
+        logger.info(
+            f"(LogProb evaluator) Found {len(query_probs)} query tokens {tokens} in {i} step with probabilities."
+        )
+        total_probs = len(probabilities)
+        if sql_generation.evaluation_quantile > MAX_QUANTILE:
+            raise ValueError(
+                f"Evaluation quantile should be between 0 and 100. Got {sql_generation.evaluation_quantile}"
+            )
+        index = int(
+            round(((sql_generation.evaluation_quantile / 100) * (total_probs - 1)), 0)
+        )
+        return Evaluation(
+            question_id=user_prompt.id,
+            answer_id=sql_generation.id,
+            score=probabilities[index],
+        )
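The returned score is the evaluation_quantile-th percentile of the query-token probabilities, located by nearest-rank rounding. A worked sketch on five hypothetical probabilities:

probabilities = sorted([0.98, 0.51, 0.93, 0.88, 0.99])  # [0.51, 0.88, 0.93, 0.98, 0.99]

evaluation_quantile = 25  # the request default
index = int(round((evaluation_quantile / 100) * (len(probabilities) - 1), 0))  # round(1.0) -> 1
score = probabilities[index]  # 0.88: one weak token drags the score down

# quantile 0 would return the least confident token (0.51),
# quantile 100 the most confident one (0.99).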

dataherald/model/chat_model.py (+2)

@@ -27,6 +27,8 @@ def get_model(
                 model_name=model_name,
                 openai_api_key=api_key,
                 openai_api_base=api_base,
+                logprobs=True,
+                top_logprobs=20,
                 **kwargs
             )
         if model_family == "anthropic":
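logprobs=True and top_logprobs=20 are forwarded to OpenAI's chat completions endpoint; this is the data the callback handler below consumes. A minimal sketch of the raw response shape, assuming the openai Python SDK (the model name is illustrative):

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set
resp = client.chat.completions.create(
    model="gpt-4o-mini",  # hypothetical model choice
    messages=[{"role": "user", "content": "Write one SQL query."}],
    logprobs=True,     # attach a logprob to every sampled token
    top_logprobs=20,   # plus the 20 most likely alternatives per position
)
for entry in resp.choices[0].logprobs.content:
    print(entry.token, entry.logprob, len(entry.top_logprobs))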

dataherald/services/sql_generations.py (+5)

@@ -61,6 +61,8 @@ def update_the_initial_sql_generation(
         initial_sql_generation.status = sql_generation.status
         initial_sql_generation.error = sql_generation.error
         initial_sql_generation.intermediate_steps = sql_generation.intermediate_steps
+        initial_sql_generation.tokens = sql_generation.tokens
+        initial_sql_generation.probs = sql_generation.probs
         return self.sql_generation_repository.update(initial_sql_generation)

     def create(
@@ -173,6 +175,9 @@ def create(
             database_connection=db_connection,
         )
         initial_sql_generation.evaluate = sql_generation_request.evaluate
+        initial_sql_generation.evaluation_quantile = (
+            sql_generation_request.evaluation_quantile
+        )
         initial_sql_generation.confidence_score = confidence_score
         return self.update_the_initial_sql_generation(
             initial_sql_generation, sql_generation

dataherald/sql_generator/dataherald_finetuning_agent.py (+9)

@@ -42,6 +42,9 @@
     DatabaseConnection,
 )
 from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator
+from dataherald.sql_generator.log_probs_callback_handler import (
+    OpenAILogProbsCallbackHandler,
+)
 from dataherald.types import FineTuningStatus, Prompt, SQLGeneration
 from dataherald.utils.agent_prompts import (
     ERROR_PARSING_MESSAGE,
@@ -533,6 +536,7 @@ def generate_response(
             Response: The response to the user question.
         """
         context_store = self.system.instance(ContextStore)
+        log_prob_callback = OpenAILogProbsCallbackHandler()
         storage = self.system.instance(DB)
         response = SQLGeneration(
             prompt_id=user_prompt.id,
@@ -543,6 +547,7 @@ def generate_response(
         self.llm = self.model.get_model(
             database_connection=database_connection,
             temperature=0,
+            callbacks=BaseCallbackManager([log_prob_callback]),
             model_name=self.llm_config.llm_name,
             api_base=self.llm_config.api_base,
         )
@@ -608,6 +613,8 @@ def generate_response(
                 completed_at=datetime.datetime.now(),
                 sql="",
                 status="INVALID",
+                tokens=log_prob_callback.tokens,
+                probs=log_prob_callback.probs,
                 error=str(e),
             )
             sql_query = ""
@@ -621,6 +628,8 @@ def generate_response(
         response.sql = replace_unprocessable_characters(sql_query)
         response.tokens_used = cb.total_tokens
         response.completed_at = datetime.datetime.now()
+        response.tokens = log_prob_callback.tokens
+        response.probs = log_prob_callback.probs
         response.intermediate_steps = self.construct_intermediate_steps(
             result["intermediate_steps"], FINETUNING_AGENT_SUFFIX
         )
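Because an agent run issues several LLM calls, the handler collects one token list (and one matching probability list) per call; the evaluator then walks these lists from the last one backwards until it finds a step containing a backtick-fenced query. A sketch of the resulting shape, with hypothetical contents:

log_prob_callback = OpenAILogProbsCallbackHandler()
# ...after an agent run with three LLM calls, the shape is list-of-lists:
# log_prob_callback.tokens == [
#     ["I", " should", " inspect", " the", " schema"],           # call 1: planning
#     ["The", " orders", " table", " is", " relevant"],          # call 2: tool reasoning
#     ["```", "sql", "\n", "SELECT", " *", " FROM", " orders"],  # call 3: final answer
# ]
# log_prob_callback.probs mirrors this shape, one probability per token.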

dataherald/sql_generator/dataherald_sqlagent.py (+9)

@@ -43,6 +43,9 @@
     DatabaseConnection,
 )
 from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator
+from dataherald.sql_generator.log_probs_callback_handler import (
+    OpenAILogProbsCallbackHandler,
+)
 from dataherald.types import Prompt, SQLGeneration
 from dataherald.utils.agent_prompts import (
     AGENT_PREFIX,
@@ -679,6 +682,7 @@ def generate_response(
         metadata: dict = None,
     ) -> SQLGeneration:
         context_store = self.system.instance(ContextStore)
+        log_prob_callback = OpenAILogProbsCallbackHandler()
         storage = self.system.instance(DB)
         response = SQLGeneration(
             prompt_id=user_prompt.id,
@@ -688,6 +692,7 @@ def generate_response(
         self.llm = self.model.get_model(
             database_connection=database_connection,
             temperature=0,
+            callbacks=BaseCallbackManager([log_prob_callback]),
             model_name=self.llm_config.llm_name,
             api_base=self.llm_config.api_base,
         )
@@ -748,6 +753,8 @@ def generate_response(
                 completed_at=datetime.datetime.now(),
                 sql="",
                 status="INVALID",
+                tokens=log_prob_callback.tokens,
+                probs=log_prob_callback.probs,
                 error=str(e),
             )
             sql_query = ""
@@ -761,6 +768,8 @@ def generate_response(
         response.sql = replace_unprocessable_characters(sql_query)
         response.tokens_used = cb.total_tokens
         response.completed_at = datetime.datetime.now()
+        response.tokens = log_prob_callback.tokens
+        response.probs = log_prob_callback.probs
         if number_of_samples > 0:
             suffix = SUFFIX_WITH_FEW_SHOT_SAMPLES
         else:
dataherald/sql_generator/log_probs_callback_handler.py (+60)

@@ -0,0 +1,60 @@
+import math
+from typing import Any, Dict
+
+from langchain.schema import AgentFinish
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+
+class OpenAILogProbsCallbackHandler(BaseCallbackHandler):
+    """Callback Handler that tracks OpenAI logprobs."""
+
+    tokens: list[list[str]]
+    probs: list[list[float]]
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.tokens = []
+        self.probs = []
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:  # noqa: ARG002
+        for generation in response.generations:
+            model_ouptut = generation[0]
+            temp_tokens = []
+            temp_probs = []
+            logprobs = model_ouptut.generation_info["logprobs"]
+            if logprobs is None:
+                continue
+            for token in logprobs["content"]:
+                top_token = token.get("token")
+                top_token_prob = round(math.exp(token.get("logprob")), 3)
+                for index, candidate in enumerate(token.get("top_logprobs")):
+                    if index == 0:
+                        continue
+                    candidate_token = candidate.get("token")
+                    candidate_prob = round(math.exp(candidate.get("logprob")), 3)
+                    if (
+                        top_token.strip().lower() in candidate_token.strip().lower()
+                        or candidate_token.strip().lower() in top_token.strip().lower()
+                    ):
+                        top_token_prob += candidate_prob
+                temp_tokens.append(top_token)
+                temp_probs.append(top_token_prob)
+            self.tokens.append(temp_tokens)
+            self.probs.append(temp_probs)
+
+    def on_chain_end(
+        self, outputs: Dict[str, Any], **kwargs: Any
+    ) -> Any:  # noqa: ARG002
+        """Run when chain ends running."""
+        pass
+
+    def on_tool_end(self, output: str, **kwargs: Any) -> Any:  # noqa: ARG002
+        """Run when tool ends running."""
+        pass
+
+    def on_agent_finish(
+        self, finish: AgentFinish, **kwargs: Any
+    ) -> Any:  # noqa: ARG002
+        """Run on agent end."""
+        pass
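In on_llm_end above, any top-20 candidate whose text is a substring match for the chosen token (after stripping and lowercasing) has its probability folded into the chosen token's. A worked sketch of that rule on one hypothetical logprobs entry, using the same dict shape the OpenAI API returns:

import math

# Hypothetical entry for one sampled token, shaped like logprobs["content"][j].
token = {
    "token": "SELECT",
    "logprob": math.log(0.90),
    "top_logprobs": [
        {"token": "SELECT", "logprob": math.log(0.90)},   # index 0: the token itself, skipped
        {"token": " select", "logprob": math.log(0.05)},  # whitespace/casing variant, merged
        {"token": "WITH", "logprob": math.log(0.04)},     # unrelated candidate, ignored
    ],
}

top_token = token["token"]
top_token_prob = round(math.exp(token["logprob"]), 3)
for index, candidate in enumerate(token["top_logprobs"]):
    if index == 0:
        continue
    a = top_token.strip().lower()
    b = candidate["token"].strip().lower()
    if a in b or b in a:
        top_token_prob += round(math.exp(candidate["logprob"]), 3)

print(top_token, top_token_prob)  # SELECT 0.95 -- 0.90 plus the 0.05 " select" variant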

dataherald/types.py (+3)

@@ -191,6 +191,7 @@ class SQLGeneration(BaseModel):
     low_latency_mode: bool = False
     llm_config: LLMConfig | None
     evaluate: bool = False
+    evaluation_quantile: int = 0
     intermediate_steps: list[IntermediateStep] | None
     sql: str | None
     status: str = "INVALID"
@@ -199,6 +200,8 @@ class SQLGeneration(BaseModel):
     confidence_score: float | None
     error: str | None
     created_at: datetime = Field(default_factory=datetime.now)
+    tokens: list[str] | None
+    probs: list[float] | None
     metadata: dict | None
