
Commit 3dbd483

DH-5567/adding intermediate steps as the response (#440)
* DH-5567/adding intermediate steps as the response
* DH-5557/add truncation
* DH-5567/reformat
1 parent 88ee8fa commit 3dbd483

12 files changed: +159 -28 lines

dataherald/api/fastapi.py (-1)

@@ -731,7 +731,6 @@ def create_prompt_and_sql_generation(
             return error_response(
                 e, prompt_sql_generation_request.dict(), "sql_generation_not_created"
             )
-
         return SQLGenerationResponse(**sql_generation.dict())
 
     @override

dataherald/api/types/responses.py (+2 -1)

@@ -5,7 +5,7 @@
 
 from dataherald.db_scanner.models.types import TableDescription
 from dataherald.sql_database.models.types import DatabaseConnection
-from dataherald.types import GoldenSQL, LLMConfig
+from dataherald.types import GoldenSQL, IntermediateStep, LLMConfig
 
 
 class BaseResponse(BaseModel):

@@ -33,6 +33,7 @@ class SQLGenerationResponse(BaseResponse):
     status: str
     completed_at: str | None
     llm_config: LLMConfig | None
+    intermediate_steps: list[IntermediateStep] | None
     sql: str | None
     tokens_used: int | None
     confidence_score: float | None
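
No other endpoint changes are needed for the new field: the API builds its response with `SQLGenerationResponse(**sql_generation.dict())` (see the fastapi.py hunk above), so any field defined on both the stored model and the response model flows through automatically. A minimal, self-contained sketch of that pass-through, using trimmed stand-in models rather than the real dataherald classes:

```python
# Trimmed stand-in models for illustration only; field names follow the diff above.
from pydantic import BaseModel


class Step(BaseModel):  # stand-in for IntermediateStep
    action: str
    observation: str


class StoredGeneration(BaseModel):  # stand-in for dataherald.types.SQLGeneration
    id: str
    intermediate_steps: list[Step] | None = None
    sql: str | None = None


class GenerationResponse(BaseModel):  # stand-in for SQLGenerationResponse
    id: str
    intermediate_steps: list[Step] | None = None
    sql: str | None = None


stored = StoredGeneration(
    id="g-1",
    sql="SELECT 1",
    intermediate_steps=[Step(action="SqlDbQuery", observation="masked")],
)
# The shared fields, including the new intermediate_steps list, pass straight through.
response = GenerationResponse(**stored.dict())
print(response.intermediate_steps[0].action)  # -> SqlDbQuery
```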

dataherald/services/sql_generations.py (+14 -6)

@@ -48,6 +48,17 @@ def generate_response_with_timeout(self, sql_generator, user_prompt, db_connecti
             user_prompt=user_prompt, database_connection=db_connection
         )
 
+    def update_the_initial_sql_generation(
+        self, initial_sql_generation: SQLGeneration, sql_generation: SQLGeneration
+    ):
+        initial_sql_generation.sql = sql_generation.sql
+        initial_sql_generation.tokens_used = sql_generation.tokens_used
+        initial_sql_generation.completed_at = datetime.now()
+        initial_sql_generation.status = sql_generation.status
+        initial_sql_generation.error = sql_generation.error
+        initial_sql_generation.intermediate_steps = sql_generation.intermediate_steps
+        return self.sql_generation_repository.update(initial_sql_generation)
+
     def create(
         self, prompt_id: str, sql_generation_request: SQLGenerationRequest
     ) -> SQLGeneration:

@@ -153,12 +164,9 @@ def create(
         )
         initial_sql_generation.evaluate = sql_generation_request.evaluate
         initial_sql_generation.confidence_score = confidence_score
-        initial_sql_generation.sql = sql_generation.sql
-        initial_sql_generation.tokens_used = sql_generation.tokens_used
-        initial_sql_generation.completed_at = datetime.now()
-        initial_sql_generation.status = sql_generation.status
-        initial_sql_generation.error = sql_generation.error
-        return self.sql_generation_repository.update(initial_sql_generation)
+        return self.update_the_initial_sql_generation(
+            initial_sql_generation, sql_generation
+        )
 
     def start_streaming(
         self, prompt_id: str, sql_generation_request: SQLGenerationRequest, queue: Queue
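
The service change is a pure extraction: the block that copies generator output back onto the stored placeholder now lives in one helper, with `intermediate_steps` added to the fields it carries over. A self-contained sketch of the pattern with dataclass stand-ins (not the real service, repository, or models):

```python
# Dataclass stand-ins for illustration only; the real helper persists the record
# through self.sql_generation_repository.update(...) rather than returning it directly.
from dataclasses import dataclass
from datetime import datetime


@dataclass
class SQLGeneration:  # trimmed stand-in for dataherald.types.SQLGeneration
    prompt_id: str
    status: str = "INITIALIZED"
    sql: str | None = None
    tokens_used: int | None = None
    completed_at: datetime | None = None
    error: str | None = None
    intermediate_steps: list | None = None


def update_the_initial_sql_generation(initial: SQLGeneration, result: SQLGeneration) -> SQLGeneration:
    # Copy everything the generator produced onto the stored placeholder record.
    initial.sql = result.sql
    initial.tokens_used = result.tokens_used
    initial.completed_at = datetime.now()
    initial.status = result.status
    initial.error = result.error
    initial.intermediate_steps = result.intermediate_steps  # the new field travels too
    return initial


placeholder = SQLGeneration(prompt_id="p-123")
result = SQLGeneration(
    prompt_id="p-123",
    status="VALID",
    sql="SELECT 1",
    tokens_used=420,
    intermediate_steps=[{"action": "SqlDbQuery"}],
)
print(update_the_initial_sql_generation(placeholder, result).status)  # -> VALID
```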

dataherald/sql_generator/__init__.py (+42 -17)

@@ -22,7 +22,7 @@
 from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
 from dataherald.sql_database.models.types import DatabaseConnection
 from dataherald.sql_generator.create_sql_query_status import create_sql_query_status
-from dataherald.types import LLMConfig, Prompt, SQLGeneration
+from dataherald.types import IntermediateStep, LLMConfig, Prompt, SQLGeneration
 from dataherald.utils.strings import contains_line_breaks
 
 

@@ -74,20 +74,6 @@ def create_sql_query_status(
     ) -> SQLGeneration:
         return create_sql_query_status(db, query, sql_generation)
 
-    def format_intermediate_representations(
-        self, intermediate_representation: List[Tuple[AgentAction, str]]
-    ) -> List[str]:
-        """Formats the intermediate representation into a string."""
-        formatted_intermediate_representation = []
-        for item in intermediate_representation:
-            formatted_intermediate_representation.append(
-                f"Thought: '{str(item[0].log).split('Action:')[0]}'\n"
-                f"Action: '{item[0].tool}'\n"
-                f"Action Input: '{item[0].tool_input}'\n"
-                f"Observation: '{item[1]}'"
-            )
-        return formatted_intermediate_representation
-
     def format_sql_query(self, sql_query: str) -> str:
         comments = [
             match.group() for match in re.finditer(r"--.*$", sql_query, re.MULTILINE)

@@ -110,14 +96,53 @@ def extract_query_from_intermediate_steps(
             action = step[0]
             if type(action) == AgentAction and action.tool == "SqlDbQuery":
                 if "SELECT" in self.format_sql_query(action.tool_input).upper():
-                    sql_query = self.remove_markdown(sql_query)
+                    sql_query = self.remove_markdown(action.tool_input)
         if sql_query == "":
             for step in intermediate_steps:
                 action = step[0]
                 if "SELECT" in action.tool_input.upper():
-                    sql_query = self.remove_markdown(sql_query)
+                    sql_query = self.remove_markdown(action.tool_input)
+                    if not sql_query.upper().strip().startswith("SELECT"):
+                        sql_query = ""
         return sql_query
 
+    def construct_intermediate_steps(
+        self, intermediate_steps: List[Tuple[AgentAction, str]], suffix: str = ""
+    ) -> List[IntermediateStep]:
+        """Constructs the intermediate steps."""
+        formatted_intermediate_steps = []
+        for step in intermediate_steps:
+            if step[0].tool == "SqlDbQuery":
+                formatted_intermediate_steps.append(
+                    IntermediateStep(
+                        thought=str(step[0].log).split("Action:")[0],
+                        action=step[0].tool,
+                        action_input=step[0].tool_input,
+                        observation="QUERY RESULTS ARE NOT STORED FOR PRIVACY REASONS.",
+                    )
+                )
+            else:
+                formatted_intermediate_steps.append(
+                    IntermediateStep(
+                        thought=str(step[0].log).split("Action:")[0],
+                        action=step[0].tool,
+                        action_input=step[0].tool_input,
+                        observation=self.truncate_observations(step[1]),
+                    )
+                )
+        formatted_intermediate_steps[0].thought = suffix.split("Thought: ")[1].split(
+            "{agent_scratchpad}"
+        )[0]
+        return formatted_intermediate_steps
+
+    def truncate_observations(self, obervarion: str, max_length: int = 2000) -> str:
+        """Truncate the tool input."""
+        return (
+            obervarion[:max_length] + "... (truncated)"
+            if len(obervarion) > max_length
+            else obervarion
+        )
+
     @abstractmethod
     def generate_response(
         self,
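
To make the new base-class helpers concrete, here is a rough, self-contained sketch of what `construct_intermediate_steps` and `truncate_observations` produce for a two-step agent run. The `AgentAction` dataclass stands in for langchain's class and the suffix string is an invented example; the loop mirrors the method's logic rather than calling the real code:

```python
# Stand-in for langchain's AgentAction, for illustration only.
from dataclasses import dataclass


@dataclass
class AgentAction:
    tool: str
    tool_input: str
    log: str


steps = [
    (
        AgentAction(
            "FewshotExamplesRetriever",
            "5",
            "Thought: let me look for similar questions\nAction: FewshotExamplesRetriever",
        ),
        "Found 5 examples of similar questions. " * 100,  # long observation, gets truncated
    ),
    (
        AgentAction("SqlDbQuery", "SELECT 1", "Thought: run the query\nAction: SqlDbQuery"),
        "raw query results",  # masked entirely for privacy
    ),
]
suffix = "... Thought: I should first collect a few examples.\n{agent_scratchpad}"

formatted = []
for action, observation in steps:
    if action.tool == "SqlDbQuery":
        masked = "QUERY RESULTS ARE NOT STORED FOR PRIVACY REASONS."
    else:
        masked = observation[:2000] + "... (truncated)" if len(observation) > 2000 else observation
    formatted.append(
        {
            "thought": action.log.split("Action:")[0],
            "action": action.tool,
            "action_input": action.tool_input,
            "observation": masked,
        }
    )

# The first step's thought is overwritten with the static thought baked into the prompt suffix.
formatted[0]["thought"] = suffix.split("Thought: ")[1].split("{agent_scratchpad}")[0]
print(formatted[0]["thought"])      # -> "I should first collect a few examples.\n"
print(formatted[1]["observation"])  # -> the privacy placeholder, never the query results
```

Two points worth noting from the diff: query results are never stored in the recorded steps, and every other tool observation is capped at 2000 characters before being saved.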

dataherald/sql_generator/dataherald_finetuning_agent.py (+4 -1)

@@ -228,7 +228,7 @@ async def _arun(
 class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
     """Tool for querying a SQL database."""
 
-    name = "ExecuteQuery"
+    name = "SqlDbQuery"
     description = """
     Input: SQL query.
     Output: Result from the database or an error message if the query is incorrect.

@@ -591,6 +591,9 @@ def generate_response(
         response.sql = replace_unprocessable_characters(sql_query)
         response.tokens_used = cb.total_tokens
         response.completed_at = datetime.datetime.now()
+        response.intermediate_steps = self.construct_intermediate_steps(
+            result["intermediate_steps"], FINETUNING_AGENT_SUFFIX
+        )
         return self.create_sql_query_status(
             self.database,
             response.sql,
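
The tool rename from `ExecuteQuery` to `SqlDbQuery` is presumably what lets the fine-tuning agent share the base-class behavior above, since both `extract_query_from_intermediate_steps` and the observation masking in `construct_intermediate_steps` key off the `SqlDbQuery` name. A tiny illustrative check (the helper below is hypothetical, not from the codebase):

```python
# Hypothetical standalone check: masking keys off the tool name, so a tool still
# named "ExecuteQuery" would have had its raw query results recorded.
def mask_observation(tool_name: str, observation: str) -> str:
    if tool_name == "SqlDbQuery":
        return "QUERY RESULTS ARE NOT STORED FOR PRIVACY REASONS."
    return observation


print(mask_observation("SqlDbQuery", "id | name\n1 | Alice"))      # masked
print(mask_observation("FewshotExamplesRetriever", "5 examples"))  # kept as-is
```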

dataherald/sql_generator/dataherald_sqlagent.py (+7)

@@ -736,6 +736,13 @@ def generate_response(
         response.sql = replace_unprocessable_characters(sql_query)
         response.tokens_used = cb.total_tokens
         response.completed_at = datetime.datetime.now()
+        if number_of_samples > 0:
+            suffix = SUFFIX_WITH_FEW_SHOT_SAMPLES
+        else:
+            suffix = SUFFIX_WITHOUT_FEW_SHOT_SAMPLES
+        response.intermediate_steps = self.construct_intermediate_steps(
+            result["intermediate_steps"], suffix=suffix
+        )
         return self.create_sql_query_status(
             self.database,
             response.sql,

dataherald/types.py (+8)

@@ -177,13 +177,21 @@ class LLMConfig(BaseModel):
     api_base: str | None = None
 
 
+class IntermediateStep(BaseModel):
+    thought: str
+    action: str
+    action_input: str
+    observation: str
+
+
 class SQLGeneration(BaseModel):
     id: str | None = None
     prompt_id: str
     finetuning_id: str | None
     low_latency_mode: bool = False
     llm_config: LLMConfig | None
     evaluate: bool = False
+    intermediate_steps: list[IntermediateStep] | None
     sql: str | None
     status: str = "INVALID"
     completed_at: datetime | None
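
A quick, self-contained check of the wire shape: each `IntermediateStep` serializes to the four-key object that the API docs further below show under `intermediate_steps` (the example values here are made up):

```python
from pydantic import BaseModel


class IntermediateStep(BaseModel):  # same fields as the model added above
    thought: str
    action: str
    action_input: str
    observation: str


step = IntermediateStep(
    thought="I should collect examples of Question/SQL pairs first.\n",
    action="FewshotExamplesRetriever",
    action_input="5",
    observation="Found 5 examples of similar questions.",
)
print(step.dict())
# {'thought': '...', 'action': 'FewshotExamplesRetriever',
#  'action_input': '5', 'observation': 'Found 5 examples of similar questions.'}
```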

dataherald/utils/agent_prompts.py (+2 -2)

@@ -121,7 +121,7 @@
 If SQL results has None or NULL values, handle them by adding a WHERE clause to filter them out.
 If SQL query doesn't follow the instructions or return incorrect results modify the SQL query to fit the instructions and fix the errors.
 Only make minor modifications to the SQL query, do not change the SQL query completely.
-You MUST always use the ExecuteQuery tool to make sure the SQL query is correct before returning it.
+You MUST always use the SqlDbQuery tool to make sure the SQL query is correct before returning it.
 
 ### Instructions from the database administrator:
 {admin_instructions}

@@ -134,7 +134,7 @@
 #
 Here is the plan you have to follow:
 1) Use the `GenerateSql` tool to generate a SQL query for the given question.
-2) Always Use the `ExecuteQuery` tool to execute the SQL query on the database to check if the results are correct.
+2) Always Use the `SqlDbQuery` tool to execute the SQL query on the database to check if the results are correct.
 #
 
 ### Instructions from the database administrator:

docs/api.create_prompt_sql_generation.rst (+16)

@@ -67,6 +67,13 @@ HTTP 201 code response
       "llm_name": "gpt-4-turbo-preview",
       "api_base": "string"
     },
+    "intermediate_steps": [
+      {
+        "action": "string",
+        "action_input": "string",
+        "observation": "string"
+      }
+    ],
     "sql": "string",
     "tokens_used": 0,
     "confidence_score": 0,

@@ -113,6 +120,15 @@ HTTP 201 code response
       "llm_name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
       "api_base": "https://tt5h145hsc119q-8000.proxy.runpod.net/v1"
     },
+    intermediate_steps": [
+      {
+        "thought": "I should Collect examples of Question/SQL pairs to check if there is a similar question among the examples.\n",
+        "action": "FewshotExamplesRetriever",
+        "action_input": "5",
+        "observation": "samples ... "
+      },
+      ...
+    ],
     "sql": "SELECT metric_value \nFROM renthub_median_rent \nWHERE period_type = 'monthly' \nAND geo_type = 'city' \nAND location_name = 'Miami' \nAND property_type = 'All Residential' \nAND period_end = (SELECT DATE_TRUNC('MONTH', CURRENT_DATE()) - INTERVAL '1 day')\nLIMIT 10",
     "tokens_used": 18115,
     "confidence_score": 0.95,

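Client code can read the new field straight off the JSON documented above. A hedged sketch with `requests`; the host, route, and request body are assumptions based on a typical local dataherald deployment and the request schema in these docs, so adjust them to your setup:

```python
# Assumed host, route, and body; adjust to your deployment. Creates a prompt plus
# SQL generation, then prints the recorded intermediate steps from the 201 response.
import requests

payload = {
    "prompt": {
        "text": "What is the median rent in Miami?",
        "db_connection_id": "your_db_connection_id",  # placeholder
    },
    "llm_config": {"llm_name": "gpt-4-turbo-preview"},
    "evaluate": False,
}

resp = requests.post(
    "http://localhost/api/v1/prompts/sql-generations",  # assumed route
    json=payload,
    timeout=300,
)
resp.raise_for_status()

for step in resp.json().get("intermediate_steps") or []:
    # Each entry mirrors IntermediateStep: thought / action / action_input / observation.
    print(f"{step['action']}({step['action_input']!r}) -> {step['observation'][:80]}")
```
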
docs/api.create_sql_generation.rst (+16)

@@ -61,6 +61,13 @@ HTTP 201 code response
       "llm_name": "gpt-4-turbo-preview",
       "api_base": "string"
     },
+    "intermediate_steps": [
+      {
+        "action": "string",
+        "action_input": "string",
+        "observation": "string"
+      }
+    ],
     "sql": "string",
     "tokens_used": 0,
     "confidence_score": 0,

@@ -102,6 +109,15 @@ HTTP 201 code response
       "llm_name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
       "api_base": "https://tt5h145hsc119q-8000.proxy.runpod.net/v1"
     },
+    intermediate_steps": [
+      {
+        "thought": "I should Collect examples of Question/SQL pairs to check if there is a similar question among the examples.\n",
+        "action": "FewshotExamplesRetriever",
+        "action_input": "5",
+        "observation": "samples ... "
+      },
+      ...
+    ],
     "sql": "SELECT metric_value \nFROM renthub_median_rent \nWHERE period_type = 'monthly' \nAND geo_type = 'city' \nAND location_name = 'Miami' \nAND property_type = 'All Residential' \nAND period_end = (SELECT DATE_TRUNC('MONTH', CURRENT_DATE()) - INTERVAL '1 day')\nLIMIT 10",
     "tokens_used": 18115,
     "confidence_score": null,

docs/api.get_sql_generation.rst (+24)

@@ -28,6 +28,17 @@ HTTP 200 code response
     "finetuning_id": "string",
     "status": "string",
     "completed_at": "string",
+    "llm_config": {
+      "llm_name": "gpt-4-turbo-preview",
+      "api_base": "string"
+    },
+    "intermediate_steps": [
+      {
+        "action": "string",
+        "action_input": "string",
+        "observation": "string"
+      }
+    ],
     "sql": "string",
     "tokens_used": 0,
     "confidence_score": 0,

@@ -54,6 +65,19 @@ HTTP 200 code response
     "finetuning_id": null,
     "status": "VALID",
     "completed_at": "2024-01-04 21:11:27.235000+00:00",
+    "llm_config": {
+      "llm_name": "gpt-4-turbo-preview",
+      "api_base": null
+    },
+    "intermediate_steps": [
+      {
+        "thought": "I should Collect examples of Question/SQL pairs to check if there is a similar question among the examples.\n",
+        "action": "FewshotExamplesRetriever",
+        "action_input": "5",
+        "observation": "Found 5 examples of similar questions."
+      },
+      ...
+    ],
     "sql": "\nSELECT dh_zip_code, MAX(metric_value) as highest_rent -- Select the zip code and the maximum rent value\nFROM renthub_average_rent\nWHERE dh_county_name = 'Los Angeles' -- Filter for Los Angeles county\nAND period_start <= '2022-05-01' -- Filter for the period that starts on or before May 1st, 2022\nAND period_end >= '2022-05-31' -- Filter for the period that ends on or after May 31st, 2022\nGROUP BY dh_zip_code -- Group by zip code to aggregate rent values\nORDER BY highest_rent DESC -- Order by the highest rent in descending order\nLIMIT 1; -- Limit to the top result\n",
     "tokens_used": 12185,
     "confidence_score": null,

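The same field comes back when fetching an existing generation (the GET endpoint documented above); a companion sketch, again with an assumed route and a placeholder id:

```python
# Assumed host and route; adjust to your deployment. Fetches one SQL generation
# by id and summarizes its recorded intermediate steps.
import requests

sql_generation_id = "65a1d9..."  # placeholder id
resp = requests.get(
    f"http://localhost/api/v1/sql-generations/{sql_generation_id}",  # assumed route
    timeout=60,
)
resp.raise_for_status()

steps = resp.json().get("intermediate_steps") or []
print(f"{len(steps)} intermediate steps recorded")
for step in steps:
    print(f"- {step['action']}: {step['thought'].strip()}")
```
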
docs/api.list_sql_generations.rst (+24)

@@ -30,6 +30,17 @@ HTTP 200 code response
     "finetuning_id": "string",
     "status": "string",
     "completed_at": "string",
+    "llm_config": {
+      "llm_name": "gpt-4-turbo-preview",
+      "api_base": "string"
+    },
+    "intermediate_steps": [
+      {
+        "action": "string",
+        "action_input": "string",
+        "observation": "string"
+      }
+    ],
     "sql": "string",
     "tokens_used": 0,
     "confidence_score": 0,

@@ -58,6 +69,19 @@ HTTP 200 code response
     "finetuning_id": null,
     "status": "VALID",
     "completed_at": "2024-01-03 18:54:55.091000+00:00",
+    "llm_config": {
+      "llm_name": "gpt-4-turbo-preview",
+      "api_base": null
+    },
+    "intermediate_steps": [
+      {
+        "thought": "I should Collect examples of Question/SQL pairs to check if there is a similar question among the examples.\n",
+        "action": "FewshotExamplesRetriever",
+        "action_input": "5",
+        "observation": "Found 5 examples of similar questions."
+      },
+      ...
+    ],
     "sql": "\nSELECT metric_value -- Rent price\nFROM renthub_median_rent\nWHERE geo_type='city' -- Focusing on city-level data\n    AND dh_state_name = 'California' -- State is California\n    AND dh_place_name = 'Los Angeles' -- City is Los Angeles\n    AND period_start = '2023-06-01' -- Most recent data available\nORDER BY metric_value DESC -- In case there are multiple entries, order by price descending\nLIMIT 1; -- Only need the top result\n",
     "tokens_used": 9491,
     "confidence_score": null,
