1
1
"""Base class that all sql generation classes inherit from."""
2
+ import datetime
3
+ import logging
2
4
import os
3
5
import re
4
6
from abc import ABC , abstractmethod
5
- from typing import Any , List , Tuple
7
+ from queue import Queue
8
+ from typing import Any , Dict , List , Tuple
6
9
7
10
import sqlparse
8
- from langchain .schema import AgentAction
11
+ from langchain .agents .agent import AgentExecutor
12
+ from langchain .callbacks .base import BaseCallbackHandler
13
+ from langchain .schema import AgentAction , LLMResult
14
+ from langchain .schema .messages import BaseMessage
15
+ from langchain_community .callbacks import get_openai_callback
9
16
10
17
from dataherald .config import Component , System
11
18
from dataherald .model .chat_model import ChatModel
12
- from dataherald .sql_database .base import SQLDatabase
19
+ from dataherald .repositories .sql_generations import (
20
+ SQLGenerationRepository ,
21
+ )
22
+ from dataherald .sql_database .base import SQLDatabase , SQLInjectionError
13
23
from dataherald .sql_database .models .types import DatabaseConnection
14
24
from dataherald .sql_generator .create_sql_query_status import create_sql_query_status
15
25
from dataherald .types import LLMConfig , Prompt , SQLGeneration
@@ -20,6 +30,12 @@ class EngineTimeOutORItemLimitError(Exception):
20
30
pass
21
31
22
32
33
+ def replace_unprocessable_characters (text : str ) -> str :
34
+ """Replace unprocessable characters with a space."""
35
+ text = text .strip ()
36
+ return text .replace (r"\_" , "_" )
37
+
38
+
23
39
class SQLGenerator (Component , ABC ):
24
40
metadata : Any
25
41
llm : ChatModel | None = None
@@ -114,3 +130,63 @@ def generate_response(
114
130
) -> SQLGeneration :
115
131
"""Generates a response to a user question."""
116
132
pass
133
+
134
+ def stream_agent_steps ( # noqa: C901
135
+ self ,
136
+ question : str ,
137
+ agent_executor : AgentExecutor ,
138
+ response : SQLGeneration ,
139
+ sql_generation_repository : SQLGenerationRepository ,
140
+ queue : Queue ,
141
+ ):
142
+ try :
143
+ with get_openai_callback () as cb :
144
+ for chunk in agent_executor .stream ({"input" : question }):
145
+ if "actions" in chunk :
146
+ for message in chunk ["messages" ]:
147
+ queue .put (message .content + "\n " )
148
+ elif "steps" in chunk :
149
+ for step in chunk ["steps" ]:
150
+ queue .put (f"Observation: `{ step .observation } `\n " )
151
+ elif "output" in chunk :
152
+ queue .put (f'Final Answer: { chunk ["output" ]} ' )
153
+ if "```sql" in chunk ["output" ]:
154
+ response .sql = replace_unprocessable_characters (
155
+ self .remove_markdown (chunk ["output" ])
156
+ )
157
+ else :
158
+ raise ValueError ()
159
+ except SQLInjectionError as e :
160
+ raise SQLInjectionError (e ) from e
161
+ except EngineTimeOutORItemLimitError as e :
162
+ raise EngineTimeOutORItemLimitError (e ) from e
163
+ except Exception as e :
164
+ response .sql = ("" ,)
165
+ response .status = ("INVALID" ,)
166
+ response .error = (str (e ),)
167
+ finally :
168
+ queue .put (None )
169
+ response .tokens_used = cb .total_tokens
170
+ response .completed_at = datetime .datetime .now ()
171
+ if not response .error :
172
+ if response .sql :
173
+ response = self .create_sql_query_status (
174
+ self .database ,
175
+ response .sql ,
176
+ response ,
177
+ )
178
+ else :
179
+ response .status = "INVALID"
180
+ response .error = "No SQL query generated"
181
+ sql_generation_repository .update (response )
182
+
183
+ @abstractmethod
184
+ def stream_response (
185
+ self ,
186
+ user_prompt : Prompt ,
187
+ database_connection : DatabaseConnection ,
188
+ response : SQLGeneration ,
189
+ queue : Queue ,
190
+ ):
191
+ """Streams a response to a user question."""
192
+ pass
0 commit comments