Commit 0227789: Refactor to avoid error when streaming
1 parent: ebdb48b

File tree: 6 files changed, +141 -182 lines

src/backend/fastapi_app/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -51,7 +51,7 @@ def create_app(testing: bool = False):
     else:
         if not testing:
             load_dotenv(override=True)
-            logging.basicConfig(level=logging.INFO)
+            logging.basicConfig(level=logging.DEBUG)
     # Turn off particularly noisy INFO level logs from Azure Core SDK:
     logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)
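Note that switching basicConfig to DEBUG only raises the root logger's level; the Azure Core SDK logger is still pinned to WARNING on the following line, so its HTTP chatter stays suppressed. A minimal standalone sketch of that interaction (logger names other than the Azure one are illustrative):

import logging

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)

# Emitted: the root logger now lets DEBUG records through.
logging.getLogger("fastapi_app").debug("debug details visible")
# Suppressed: this logger is explicitly capped at WARNING.
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").info("HTTP chatter")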

src/backend/fastapi_app/api_models.py
Lines changed: 9 additions & 0 deletions

@@ -74,3 +74,12 @@ class ItemPublic(BaseModel):
 
 class ItemWithDistance(ItemPublic):
     distance: float
+
+
+class ChatParams(ChatRequestOverrides):
+    prompt_template: str
+    response_token_limit: int = 1024
+    enable_text_search: bool
+    enable_vector_search: bool
+    original_user_query: str
+    past_messages: list[ChatCompletionMessageParam]
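ChatParams moves here from rag_base.py (see that file's diff below) so it lives alongside the other Pydantic models; api_models.py also needs a ChatCompletionMessageParam import for the annotation, in a hunk not captured on this page. In the app it is produced by RAGChatBase.get_params(messages, overrides); a hypothetical direct construction, with made-up values, might look like:

# Hypothetical example; all field values are illustrative only.
chat_params = ChatParams(
    top=3,  # inherited from ChatRequestOverrides
    prompt_template="Answer using only the provided sources.",
    enable_text_search=True,
    enable_vector_search=True,
    original_user_query="Which shoes are best for wide feet?",
    past_messages=[{"role": "user", "content": "Hello"}],
)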

src/backend/fastapi_app/rag_advanced.py
Lines changed: 50 additions & 92 deletions

@@ -7,7 +7,6 @@
 
 from fastapi_app.api_models import (
     AIChatRoles,
-    ChatRequestOverrides,
     Message,
     RAGContext,
     RetrievalResponse,
@@ -63,10 +62,15 @@ async def generate_search_query(
 
         return query_messages, query_text, filters
 
-    async def retrieve_and_build_context(
-        self, chat_params: ChatParams, query_text: str | Any | None, filters: list
-    ) -> tuple[list[ChatCompletionMessageParam], list[Item]]:
-        """Retrieve relevant items from the database and build a context for the chat model."""
+    async def prepare_context(
+        self, chat_params: ChatParams
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]:
+        query_messages, query_text, filters = await self.generate_search_query(
+            original_user_query=chat_params.original_user_query,
+            past_messages=chat_params.past_messages,
+            query_response_token_limit=500,
+        )
+
         # Retrieve relevant items from the database with the GPT optimized query
         results = await self.searcher.search_and_embed(
             query_text,
@@ -88,28 +92,41 @@ async def retrieve_and_build_context(
             max_tokens=self.chat_token_limit - chat_params.response_token_limit,
             fallback_to_default=True,
         )
-        return contextual_messages, results
 
-    async def run(
+        thoughts = [
+            ThoughtStep(
+                title="Prompt to generate search arguments",
+                description=[str(message) for message in query_messages],
+                props=(
+                    {"model": self.chat_model, "deployment": self.chat_deployment}
+                    if self.chat_deployment
+                    else {"model": self.chat_model}
+                ),
+            ),
+            ThoughtStep(
+                title="Search using generated search arguments",
+                description=query_text,
+                props={
+                    "top": chat_params.top,
+                    "vector_search": chat_params.enable_vector_search,
+                    "text_search": chat_params.enable_text_search,
+                    "filters": filters,
+                },
+            ),
+            ThoughtStep(
+                title="Search results",
+                description=[result.to_dict() for result in results],
+            ),
+        ]
+        return contextual_messages, results, thoughts
+
+    async def answer(
         self,
-        messages: list[ChatCompletionMessageParam],
-        overrides: ChatRequestOverrides,
+        chat_params: ChatParams,
+        contextual_messages: list[ChatCompletionMessageParam],
+        results: list[Item],
+        earlier_thoughts: list[ThoughtStep],
     ) -> RetrievalResponse:
-        chat_params = self.get_params(messages, overrides)
-
-        # Generate an optimized keyword search query based on the chat history and the last question
-        query_messages, query_text, filters = await self.generate_search_query(
-            original_user_query=chat_params.original_user_query,
-            past_messages=chat_params.past_messages,
-            query_response_token_limit=500,
-        )
-
-        # Retrieve relevant items from the database with the GPT optimized query
-        # Generate a contextual and content specific answer using the search results and chat history
-        contextual_messages, results = await self.retrieve_and_build_context(
-            chat_params=chat_params, query_text=query_text, filters=filters
-        )
-
         chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
@@ -126,30 +143,8 @@ async def run(
             ),
             context=RAGContext(
                 data_points={item.id: item.to_dict() for item in results},
-                thoughts=[
-                    ThoughtStep(
-                        title="Prompt to generate search arguments",
-                        description=[str(message) for message in query_messages],
-                        props=(
-                            {"model": self.chat_model, "deployment": self.chat_deployment}
-                            if self.chat_deployment
-                            else {"model": self.chat_model}
-                        ),
-                    ),
-                    ThoughtStep(
-                        title="Search using generated search arguments",
-                        description=query_text,
-                        props={
-                            "top": chat_params.top,
-                            "vector_search": chat_params.enable_vector_search,
-                            "text_search": chat_params.enable_text_search,
-                            "filters": filters,
-                        },
-                    ),
-                    ThoughtStep(
-                        title="Search results",
-                        description=[result.to_dict() for result in results],
-                    ),
+                thoughts=earlier_thoughts
+                + [
                     ThoughtStep(
                         title="Prompt to generate answer",
                         description=[str(message) for message in contextual_messages],
@@ -163,23 +158,13 @@ async def run(
             ),
         )
 
-    async def run_stream(
+    async def answer_stream(
         self,
-        messages: list[ChatCompletionMessageParam],
-        overrides: ChatRequestOverrides,
+        chat_params: ChatParams,
+        contextual_messages: list[ChatCompletionMessageParam],
+        results: list[Item],
+        earlier_thoughts: list[ThoughtStep],
     ) -> AsyncGenerator[RetrievalResponseDelta, None]:
-        chat_params = self.get_params(messages, overrides)
-
-        query_messages, query_text, filters = await self.generate_search_query(
-            original_user_query=chat_params.original_user_query,
-            past_messages=chat_params.past_messages,
-            query_response_token_limit=500,
-        )
-
-        contextual_messages, results = await self.retrieve_and_build_context(
-            chat_params=chat_params, query_text=query_text, filters=filters
-        )
-
         chat_completion_async_stream: AsyncStream[
             ChatCompletionChunk
         ] = await self.openai_chat_client.chat.completions.create(
@@ -192,38 +177,11 @@ async def run_stream(
             stream=True,
         )
 
-        # Forcefully close the database session before yielding the response
-        # Yielding keeps the connection open while streaming the response until the end
-        # The connection closes when it returns back to the context manager in the dependencies
-        await self.searcher.db_session.close()
-
         yield RetrievalResponseDelta(
             context=RAGContext(
                 data_points={item.id: item.to_dict() for item in results},
-                thoughts=[
-                    ThoughtStep(
-                        title="Prompt to generate search arguments",
-                        description=[str(message) for message in query_messages],
-                        props=(
-                            {"model": self.chat_model, "deployment": self.chat_deployment}
-                            if self.chat_deployment
-                            else {"model": self.chat_model}
-                        ),
-                    ),
-                    ThoughtStep(
-                        title="Search using generated search arguments",
-                        description=query_text,
-                        props={
-                            "top": chat_params.top,
-                            "vector_search": chat_params.enable_vector_search,
-                            "text_search": chat_params.enable_text_search,
-                            "filters": filters,
-                        },
-                    ),
-                    ThoughtStep(
-                        title="Search results",
-                        description=[result.to_dict() for result in results],
-                    ),
+                thoughts=earlier_thoughts
+                + [
                     ThoughtStep(
                         title="Prompt to generate answer",
                         description=[str(message) for message in contextual_messages],
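Taken together, the rag_advanced.py changes split the old run/run_stream flow in two: prepare_context does the query rewriting and the database search (and now also builds the first three ThoughtSteps), while answer and answer_stream only call the OpenAI client. Since no database work remains once streaming starts, the forced db_session.close() workaround, and presumably the error the commit title refers to, goes away; the session can again be closed by the dependency's context manager. A rough sketch of how a caller might drive the new flow (the rag_flow, messages, and overrides names are assumptions; only get_params, prepare_context, and answer_stream appear in this commit):

# Hypothetical caller, e.g. the body of a FastAPI route handler.
chat_params = rag_flow.get_params(messages, overrides)

# All database access happens up front, inside the request's DB session...
contextual_messages, results, thoughts = await rag_flow.prepare_context(chat_params)

# ...so the streaming phase only talks to the OpenAI client.
async for delta in rag_flow.answer_stream(chat_params, contextual_messages, results, thoughts):
    ...  # serialize each RetrievalResponseDelta to the HTTP response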

src/backend/fastapi_app/rag_base.py
Lines changed: 20 additions & 22 deletions

@@ -4,19 +4,16 @@
 
 from openai.types.chat import ChatCompletionMessageParam
 
-from fastapi_app.api_models import ChatRequestOverrides, RetrievalResponse, RetrievalResponseDelta
+from fastapi_app.api_models import (
+    ChatParams,
+    ChatRequestOverrides,
+    RetrievalResponse,
+    RetrievalResponseDelta,
+    ThoughtStep,
+)
 from fastapi_app.postgres_models import Item
 
 
-class ChatParams(ChatRequestOverrides):
-    prompt_template: str
-    response_token_limit: int = 1024
-    enable_text_search: bool
-    enable_vector_search: bool
-    original_user_query: str
-    past_messages: list[ChatCompletionMessageParam]
-
-
 class RAGChatBase(ABC):
     current_dir = pathlib.Path(__file__).parent
     query_prompt_template = open(current_dir / "prompts/query.txt").read()
@@ -48,27 +45,28 @@ def get_params(self, messages: list[ChatCompletionMessageParam], overrides: Chat
         )
 
     @abstractmethod
-    async def retrieve_and_build_context(
-        self,
-        chat_params: ChatParams,
-        *args,
-        **kwargs,
-    ) -> tuple[list[ChatCompletionMessageParam], list[Item]]:
+    async def prepare_context(
+        self, chat_params: ChatParams
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]:
         raise NotImplementedError
 
     @abstractmethod
-    async def run(
+    async def answer(
         self,
-        messages: list[ChatCompletionMessageParam],
-        overrides: ChatRequestOverrides,
+        chat_params: ChatParams,
+        contextual_messages: list[ChatCompletionMessageParam],
+        results: list[Item],
+        earlier_thoughts: list[ThoughtStep],
     ) -> RetrievalResponse:
         raise NotImplementedError
 
     @abstractmethod
-    async def run_stream(
+    async def answer_stream(
         self,
-        messages: list[ChatCompletionMessageParam],
-        overrides: ChatRequestOverrides,
+        chat_params: ChatParams,
+        contextual_messages: list[ChatCompletionMessageParam],
+        results: list[Item],
+        earlier_thoughts: list[ThoughtStep],
     ) -> AsyncGenerator[RetrievalResponseDelta, None]:
         raise NotImplementedError
         if False: