Skip to content

Commit 1205d8a

Browse files
DH-5599/adding_straming_endpoint (#430)
1 parent 61a92c9 commit 1205d8a

File tree

9 files changed

+377
-6
lines changed

9 files changed

+377
-6
lines changed

dataherald/api/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
PromptSQLGenerationNLGenerationRequest,
1313
PromptSQLGenerationRequest,
1414
SQLGenerationRequest,
15+
StreamPromptSQLGenerationRequest,
1516
UpdateMetadataRequest,
1617
)
1718
from dataherald.api.types.responses import (
@@ -265,3 +266,10 @@ def update_nl_generation(
265266
self, nl_generation_id: str, update_metadata_request: UpdateMetadataRequest
266267
) -> NLGenerationResponse:
267268
pass
269+
270+
@abstractmethod
async def stream_create_prompt_and_sql_generation(
    self,
    request: StreamPromptSQLGenerationRequest,
):
    """Create a prompt and stream its SQL generation.

    Abstract hook with no behavior of its own; concrete API classes
    provide the implementation.

    Args:
        request: Streaming request that carries the prompt together with
            the SQL generation options.
    """

dataherald/api/fastapi.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import asyncio
12
import datetime
23
import io
4+
import json
35
import logging
46
import os
57
import time
8+
from queue import Queue
69
from typing import List
710

811
from bson.objectid import InvalidId, ObjectId
@@ -18,6 +21,7 @@
1821
PromptSQLGenerationNLGenerationRequest,
1922
PromptSQLGenerationRequest,
2023
SQLGenerationRequest,
24+
StreamPromptSQLGenerationRequest,
2125
UpdateMetadataRequest,
2226
)
2327
from dataherald.api.types.responses import (
@@ -83,7 +87,7 @@
8387
UpdateInstruction,
8488
)
8589
from dataherald.utils.encrypt import FernetEncrypt
86-
from dataherald.utils.error_codes import error_response
90+
from dataherald.utils.error_codes import error_response, stream_error_response
8791

8892
logger = logging.getLogger(__name__)
8993

@@ -884,3 +888,26 @@ def get_nl_generation(self, nl_generation_id: str) -> NLGenerationResponse:
884888
detail=f"NL Generation {nl_generation_id} not found",
885889
)
886890
return NLGenerationResponse(**nl_generations[0].dict())
891+
892+
@override
async def stream_create_prompt_and_sql_generation(
    self,
    request: StreamPromptSQLGenerationRequest,
):
    """Create a prompt, start SQL generation, and yield the streamed items.

    A ``queue.Queue`` is handed to ``SQLGenerationService.start_streaming``;
    items are then taken from it and yielded until a ``None`` sentinel
    arrives.

    Args:
        request: Streaming request carrying the prompt and the SQL
            generation options.

    Yields:
        Every non-``None`` item taken from the queue. If any exception is
        raised along the way, one JSON string built from
        ``stream_error_response`` is yielded instead.
    """
    try:
        queue = Queue()
        prompt_service = PromptService(self.storage)
        prompt = prompt_service.create(request.prompt)
        sql_generation_service = SQLGenerationService(self.system, self.storage)
        sql_generation_service.start_streaming(prompt.id, request, queue)
        while True:
            # Queue.get() blocks the calling thread. Waiting in a worker
            # thread keeps the event loop free to serve other requests,
            # and the await is itself a yield point, so no extra sleep
            # is needed.
            value = await asyncio.to_thread(queue.get)
            # Acknowledge every item, the sentinel included, so that a
            # Queue.join() on the producer side can never hang.
            queue.task_done()
            if value is None:
                break
            yield value
    except Exception as e:
        # NOTE(review): the error key says "nl_generation" although this
        # endpoint creates a SQL generation - confirm against error_codes.
        yield json.dumps(
            stream_error_response(e, request.dict(), "nl_generation_not_created")
        )

dataherald/api/types/requests.py

+11
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,21 @@ class SQLGenerationRequest(BaseModel):
1818
metadata: dict | None
1919

2020

21+
class StreamSQLGenerationRequest(BaseModel):
    # Request options for the streaming SQL generation flow.

    # Finetuning to use; may be None.
    finetuning_id: str | None
    # Off unless the caller sets it.
    low_latency_mode: bool = False
    # LLM settings; may be None.
    llm_config: LLMConfig | None
    # Free-form metadata; may be None.
    metadata: dict | None
26+
27+
2128
class PromptSQLGenerationRequest(SQLGenerationRequest):
    # SQL generation options plus the prompt to create alongside them.
    prompt: PromptRequest
2330

2431

32+
class StreamPromptSQLGenerationRequest(StreamSQLGenerationRequest):
    # Streaming SQL generation options plus the prompt to create alongside them.
    prompt: PromptRequest
34+
35+
2536
class NLGenerationRequest(BaseModel):
2637
llm_config: LLMConfig | None
2738
max_rows: int = 100

dataherald/server/fastapi/__init__.py

+16
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
PromptSQLGenerationNLGenerationRequest,
1717
PromptSQLGenerationRequest,
1818
SQLGenerationRequest,
19+
StreamPromptSQLGenerationRequest,
1920
UpdateMetadataRequest,
2021
)
2122
from dataherald.api.types.responses import (
@@ -356,6 +357,13 @@ def __init__(self, settings: Settings):
356357
tags=["Finetunings"],
357358
)
358359

360+
self.router.add_api_route(
361+
"/api/v1/stream-sql-generation",
362+
self.stream_sql_generation,
363+
methods=["POST"],
364+
tags=["Stream SQL Generation"],
365+
)
366+
359367
self.router.add_api_route(
360368
"/api/v1/heartbeat", self.heartbeat, methods=["GET"], tags=["System"]
361369
)
@@ -601,3 +609,11 @@ def update_finetuning_job(
601609
) -> Finetuning:
602610
"""Gets fine tuning jobs"""
603611
return self._api.update_finetuning_job(finetuning_id, update_metadata_request)
612+
613+
async def stream_sql_generation(
    self, request: StreamPromptSQLGenerationRequest
) -> StreamingResponse:
    """Streams the SQL generation for a new prompt as a text/event-stream response"""
    event_stream = self._api.stream_create_prompt_and_sql_generation(request)
    return StreamingResponse(event_stream, media_type="text/event-stream")

dataherald/services/sql_generations.py

+63
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
from concurrent.futures import ThreadPoolExecutor, TimeoutError
33
from datetime import datetime
4+
from queue import Queue
45

56
import pandas as pd
67

@@ -159,6 +160,68 @@ def create(
159160
initial_sql_generation.error = sql_generation.error
160161
return self.sql_generation_repository.update(initial_sql_generation)
161162

163+
def start_streaming(
    self, prompt_id: str, sql_generation_request: SQLGenerationRequest, queue: Queue
):
    """Insert a SQL generation record and start streaming the agent's output.

    Args:
        prompt_id: Id of the stored prompt to generate SQL for.
        sql_generation_request: Options for the generation (finetuning id,
            low latency flag, LLM config, metadata).
        queue: Queue passed on to the generator's ``stream_response``.

    Raises:
        PromptNotFoundError: No prompt is stored under ``prompt_id``.
        SQLGenerationError: Low latency mode was requested without a
            finetuning id, or ``stream_response`` raised.
    """
    request = sql_generation_request

    def resolve_llm_config():
        # The request's config when truthy, otherwise a newly built default.
        return request.llm_config or LLMConfig()

    initial_sql_generation = SQLGeneration(
        prompt_id=prompt_id,
        created_at=datetime.now(),
        llm_config=resolve_llm_config(),
        metadata=request.metadata,
    )
    self.sql_generation_repository.insert(initial_sql_generation)

    prompt = PromptRepository(self.storage).find_by_id(prompt_id)
    if not prompt:
        not_found_message = f"Prompt {prompt_id} not found"
        self.update_error(initial_sql_generation, not_found_message)
        raise PromptNotFoundError(not_found_message, initial_sql_generation.id)

    db_connection = DatabaseConnectionRepository(self.storage).find_by_id(
        prompt.db_connection_id
    )

    finetuning_id = request.finetuning_id
    if finetuning_id is not None and finetuning_id != "":
        sql_generator = DataheraldFinetuningAgent(
            self.system,
            resolve_llm_config(),
        )
        sql_generator.finetuning_id = finetuning_id
        sql_generator.use_fintuned_model_only = request.low_latency_mode
    elif request.low_latency_mode:
        # Low latency mode is rejected when no finetuning id is given.
        raise SQLGenerationError(
            "Low latency mode is not supported for our old agent with no finetuning. Please specify a finetuning id.",
            initial_sql_generation.id,
        )
    else:
        sql_generator = DataheraldSQLAgent(
            self.system,
            resolve_llm_config(),
        )

    initial_sql_generation.finetuning_id = request.finetuning_id
    initial_sql_generation.low_latency_mode = request.low_latency_mode

    try:
        sql_generator.stream_response(
            user_prompt=prompt,
            database_connection=db_connection,
            response=initial_sql_generation,
            queue=queue,
        )
    except Exception as e:
        failure = str(e)
        # Record the failure on the stored generation before re-raising.
        self.update_error(initial_sql_generation, failure)
        raise SQLGenerationError(failure, initial_sql_generation.id) from e
224+
162225
def get(self, query) -> list[SQLGeneration]:
    """Return the SQL generations the repository finds for ``query``."""
    matches = self.sql_generation_repository.find_by(query)
    return matches
164227

dataherald/sql_generator/__init__.py

+79-3
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
"""Base class that all sql generation classes inherit from."""
2+
import datetime
3+
import logging
24
import os
35
import re
46
from abc import ABC, abstractmethod
5-
from typing import Any, List, Tuple
7+
from queue import Queue
8+
from typing import Any, Dict, List, Tuple
69

710
import sqlparse
8-
from langchain.schema import AgentAction
11+
from langchain.agents.agent import AgentExecutor
12+
from langchain.callbacks.base import BaseCallbackHandler
13+
from langchain.schema import AgentAction, LLMResult
14+
from langchain.schema.messages import BaseMessage
15+
from langchain_community.callbacks import get_openai_callback
916

1017
from dataherald.config import Component, System
1118
from dataherald.model.chat_model import ChatModel
12-
from dataherald.sql_database.base import SQLDatabase
19+
from dataherald.repositories.sql_generations import (
20+
SQLGenerationRepository,
21+
)
22+
from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
1323
from dataherald.sql_database.models.types import DatabaseConnection
1424
from dataherald.sql_generator.create_sql_query_status import create_sql_query_status
1525
from dataherald.types import LLMConfig, Prompt, SQLGeneration
@@ -20,6 +30,12 @@ class EngineTimeOutORItemLimitError(Exception):
2030
pass
2131

2232

33+
def replace_unprocessable_characters(text: str) -> str:
    """Strip surrounding whitespace and unescape ``\\_`` into ``_``."""
    return text.strip().replace(r"\_", "_")
37+
38+
2339
class SQLGenerator(Component, ABC):
2440
metadata: Any
2541
llm: ChatModel | None = None
@@ -114,3 +130,63 @@ def generate_response(
114130
) -> SQLGeneration:
115131
"""Generates a response to a user question."""
116132
pass
133+
134+
def stream_agent_steps(  # noqa: C901
    self,
    question: str,
    agent_executor: AgentExecutor,
    response: SQLGeneration,
    sql_generation_repository: SQLGenerationRepository,
    queue: Queue,
):
    """Run the agent in streaming mode, pushing its steps onto ``queue``.

    Chunks from ``agent_executor.stream`` are handled by the key they carry:
    ``"actions"`` chunks enqueue the content of each message, ``"steps"``
    chunks enqueue each observation, and ``"output"`` chunks enqueue the
    final answer and, when it contains a ```` ```sql ```` fence, store the
    extracted query on ``response``. A ``None`` sentinel is always
    enqueued at the end, after which ``response`` is completed and saved.

    Args:
        question: Input passed to the agent.
        agent_executor: Executor whose ``stream`` method yields the chunks.
        response: SQL generation record that is filled in and persisted.
        sql_generation_repository: Repository used to save ``response``.
        queue: Destination for the streamed text and the end sentinel.

    Raises:
        SQLInjectionError: Re-raised when the agent run raises it.
        EngineTimeOutORItemLimitError: Re-raised when the agent run raises it.
    """
    # Pre-bind so the finally block works even if the callback context
    # manager itself fails on entry.
    cb = None
    try:
        with get_openai_callback() as cb:
            for chunk in agent_executor.stream({"input": question}):
                if "actions" in chunk:
                    for message in chunk["messages"]:
                        queue.put(message.content + "\n")
                elif "steps" in chunk:
                    for step in chunk["steps"]:
                        queue.put(f"Observation: `{step.observation}`\n")
                elif "output" in chunk:
                    queue.put(f'Final Answer: {chunk["output"]}')
                    if "```sql" in chunk["output"]:
                        response.sql = replace_unprocessable_characters(
                            self.remove_markdown(chunk["output"])
                        )
                else:
                    raise ValueError("Unrecognized chunk received from agent stream")
    except SQLInjectionError as e:
        raise SQLInjectionError(e) from e
    except EngineTimeOutORItemLimitError as e:
        raise EngineTimeOutORItemLimitError(e) from e
    except Exception as e:
        # Plain strings: a trailing comma here would store 1-tuples.
        response.sql = ""
        response.status = "INVALID"
        # Fall back to repr so an exception without a message still
        # leaves a non-empty error on the record.
        response.error = str(e) or repr(e)
    finally:
        # Always signal end-of-stream first so the consumer stops waiting.
        queue.put(None)
        if cb is not None:
            response.tokens_used = cb.total_tokens
        response.completed_at = datetime.datetime.now()
        if not response.error:
            if response.sql:
                response = self.create_sql_query_status(
                    self.database,
                    response.sql,
                    response,
                )
            else:
                response.status = "INVALID"
                response.error = "No SQL query generated"
        sql_generation_repository.update(response)
182+
183+
@abstractmethod
def stream_response(
    self,
    user_prompt: Prompt,
    database_connection: DatabaseConnection,
    response: SQLGeneration,
    queue: Queue,
):
    """Stream the answer to a user question.

    Abstract hook with no behavior of its own; each SQL generator
    provides the implementation.

    Args:
        user_prompt: The prompt to answer.
        database_connection: Connection details for the target database.
        response: SQL generation record associated with this run.
        queue: Queue handed to the generator for its streamed output.
    """

0 commit comments

Comments
 (0)