Skip to content

Commit 3ef6920

Browse files
committed
[DH-5597] Fix sql-generation
1 parent 1205d8a commit 3ef6920

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

dataherald/sql_generator/__init__.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def remove_markdown(self, query: str) -> str:
6060
matches = re.findall(pattern, query, re.DOTALL)
6161
if matches:
6262
return matches[0].strip()
63-
return ""
63+
return query
6464

6565
@classmethod
6666
def get_upper_bound_limit(cls) -> int:
@@ -110,15 +110,14 @@ def extract_query_from_intermediate_steps(
110110
action = step[0]
111111
if type(action) == AgentAction and action.tool == "SqlDbQuery":
112112
sql_query = self.format_sql_query(action.tool_input)
113-
if "```sql" in sql_query:
113+
if "SELECT" in sql_query.upper():
114114
sql_query = self.remove_markdown(sql_query)
115115
if sql_query == "":
116116
for step in intermediate_steps:
117117
action = step[0]
118118
sql_query = action.tool_input
119-
if "```sql" in sql_query:
119+
if "SELECT" in sql_query.upper():
120120
sql_query = self.remove_markdown(sql_query)
121-
122121
return sql_query
123122

124123
@abstractmethod

dataherald/sql_generator/dataherald_sqlagent.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ def generate_response(
730730
sql_query = self.remove_markdown(result["output"])
731731
else:
732732
sql_query = self.extract_query_from_intermediate_steps(
733-
result["intermediate"]
733+
result["intermediate_steps"]
734734
)
735735
logger.info(f"cost: {str(cb.total_cost)} tokens: {str(cb.total_tokens)}")
736736
response.sql = replace_unprocessable_characters(sql_query)

0 commit comments

Comments
 (0)