import logging

from overrides import override

from dataherald.config import System
from dataherald.eval import Evaluation, Evaluator
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.types import Prompt, SQLGeneration

logger = logging.getLogger(__name__)
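# evaluation_quantile is expressed as a percentage, so valid values lie in [0, 100].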
MAX_QUANTILE = 100


class LogProbEvaluator(Evaluator):
    def __init__(self, system: System):
        super().__init__(system)
        self.system = system

    def extract_query_probs(self, tokens, probs):
        """Extract the probabilities for each token in the query."""
        query_probs = []
        query_found = False
        for token, prob in zip(tokens, probs, strict=False):
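            # A backtick presumably marks the opening of the fenced SQL block
            # in the model output; every token from that point on is treated as
            # part of the query. strict=False tolerates token/probability lists
            # of unequal length.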
            if "`" in token:
                query_found = True
            if query_found:
                query_probs.append((token, prob))
        return query_probs

    @override
    def evaluate(
        self,
        user_prompt: Prompt,
        sql_generation: SQLGeneration,
        database_connection: DatabaseConnection,  # noqa: ARG002
    ) -> Evaluation:
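        """Score a question/SQL pair by a quantile of the query-token probabilities."""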
        logger.info(
            f"(LogProb evaluator) Generating score for the question/SQL pair: {user_prompt.text} / {sql_generation.sql}"
        )
        if sql_generation.status == "INVALID":
            logger.info(
                f"(LogProb evaluator) SQL query: {sql_generation.sql} is not valid. Returning score 0"
            )
            return Evaluation(
                question_id=user_prompt.id, answer_id=sql_generation.id, score=0.0
            )
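        # Scan the generation steps from last to first: the SQL block is
        # presumably emitted in a late step, so searching backwards finds it
        # quickly and stops at the first step that contains query tokens.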
        query_probs = []
        for i in range(len(sql_generation.tokens) - 1, -1, -1):
            query_probs = self.extract_query_probs(
                sql_generation.tokens[i], sql_generation.probs[i]
            )
            if query_probs:
                break
        if not query_probs:
            return Evaluation(
                question_id=user_prompt.id, answer_id=sql_generation.id, score=0.0
            )
        probabilities = sorted(prob for _token, prob in query_probs)
        tokens = [token for token, _prob in query_probs]
        logger.info(
            f"(LogProb evaluator) Found {len(query_probs)} query tokens {tokens} at generation step {i}."
        )
        total_probs = len(probabilities)
        if not 0 <= sql_generation.evaluation_quantile <= MAX_QUANTILE:
            raise ValueError(
                f"Evaluation quantile should be between 0 and {MAX_QUANTILE}. Got {sql_generation.evaluation_quantile}"
            )
        index = int(
            round((sql_generation.evaluation_quantile / 100) * (total_probs - 1))
        )
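        # Example: with 5 sorted probabilities and evaluation_quantile == 25,
        # index == round(0.25 * 4) == 1, so the score is the 25th-percentile
        # token probability; lower quantiles therefore give more conservative
        # scores.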
        return Evaluation(
            question_id=user_prompt.id,
            answer_id=sql_generation.id,
            score=probabilities[index],
        )
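
# A minimal usage sketch (illustrative only): `system`, `prompt`, `generation`,
# and `connection` are assumed to be constructed elsewhere by dataherald's
# config and storage layers; they are not defined in this module.
#
#     evaluator = LogProbEvaluator(system)
#     evaluation = evaluator.evaluate(prompt, generation, connection)
#     logger.info(f"LogProb score: {evaluation.score}")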