
Commit 34a7afe

refactor and add streaming functions
1 parent 6de0b9e commit 34a7afe

2 files changed: 286 additions, 59 deletions

src/backend/fastapi_app/rag_advanced.py

Lines changed: 136 additions & 34 deletions
@@ -1,19 +1,17 @@
-import pathlib
 from collections.abc import AsyncGenerator
-from typing import (
-    Any,
-)
+from typing import Any
 
-from openai import AsyncAzureOpenAI, AsyncOpenAI
-from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
 from openai_messages_token_helper import build_messages, get_token_limit
 
-from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
-from .postgres_searcher import PostgresSearcher
-from .query_rewriter import build_search_function, extract_search_arguments
+from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
+from fastapi_app.postgres_searcher import PostgresSearcher
+from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
+from fastapi_app.rag_simple import RAGChatBase
 
 
-class AdvancedRAGChat:
+class AdvancedRAGChat(RAGChatBase):
     def __init__(
         self,
         *,
@@ -27,29 +25,21 @@ def __init__(
         self.chat_model = chat_model
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
-        current_dir = pathlib.Path(__file__).parent
-        self.query_prompt_template = open(current_dir / "prompts/query.txt").read()
-        self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
 
     async def run(
-        self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {}
-    ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]:
-        text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
-        vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
-        top = overrides.get("top", 3)
-
-        original_user_query = messages[-1]["content"]
-        if not isinstance(original_user_query, str):
-            raise ValueError("The most recent message content must be a string.")
-        past_messages = messages[:-1]
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> RetrievalResponse:
+        chat_params = self.get_params(messages, overrides)
 
         # Generate an optimized keyword search query based on the chat history and the last question
         query_response_token_limit = 500
         query_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=self.query_prompt_template,
-            new_user_content=original_user_query,
-            past_messages=past_messages,
+            new_user_content=chat_params.original_user_query,
+            past_messages=chat_params.past_messages,
             max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
             fallback_to_default=True,
         )
@@ -65,14 +55,14 @@ async def run(
             tool_choice="auto",
         )
 
-        query_text, filters = extract_search_arguments(original_user_query, chat_completion)
+        query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
 
         # Retrieve relevant items from the database with the GPT optimized query
         results = await self.searcher.search_and_embed(
             query_text,
-            top=top,
-            enable_vector_search=vector_search,
-            enable_text_search=text_search,
+            top=chat_params.top,
+            enable_vector_search=chat_params.enable_vector_search,
+            enable_text_search=chat_params.enable_text_search,
             filters=filters,
         )
 
@@ -84,8 +74,8 @@ async def run(
         contextual_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
-            new_user_content=original_user_query + "\n\nSources:\n" + content,
-            past_messages=past_messages,
+            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
+            past_messages=chat_params.past_messages,
             max_tokens=self.chat_token_limit - response_token_limit,
             fallback_to_default=True,
         )
@@ -99,6 +89,7 @@ async def run(
             n=1,
             stream=False,
         )
+
         first_choice_message = chat_completion_response.choices[0].message
 
         return RetrievalResponse(
@@ -119,9 +110,9 @@ async def run(
                         title="Search using generated search arguments",
                         description=query_text,
                         props={
-                            "top": top,
-                            "vector_search": vector_search,
-                            "text_search": text_search,
+                            "top": chat_params.top,
+                            "vector_search": chat_params.enable_vector_search,
+                            "text_search": chat_params.enable_text_search,
                             "filters": filters,
                         },
                     ),
@@ -141,3 +132,114 @@ async def run(
                 ],
             ),
         )
+
+    async def run_stream(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> AsyncGenerator[RetrievalResponse | Message, None]:
+        chat_params = self.get_params(messages, overrides)
+
+        # Generate an optimized keyword search query based on the chat history and the last question
+        query_response_token_limit = 500
+        query_messages: list[ChatCompletionMessageParam] = build_messages(
+            model=self.chat_model,
+            system_prompt=self.query_prompt_template,
+            new_user_content=chat_params.original_user_query,
+            past_messages=chat_params.past_messages,
+            max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
+            fallback_to_default=True,
+        )
+
+        chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
+            messages=query_messages,
+            # Azure OpenAI takes the deployment name as the model name
+            model=self.chat_deployment if self.chat_deployment else self.chat_model,
+            temperature=0.0,  # Minimize creativity for search query generation
+            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, too high risks performance
+            n=1,
+            tools=build_search_function(),
+            tool_choice="auto",
+        )
+
+        query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
+
+        # Retrieve relevant items from the database with the GPT optimized query
+        results = await self.searcher.search_and_embed(
+            query_text,
+            top=chat_params.top,
+            enable_vector_search=chat_params.enable_vector_search,
+            enable_text_search=chat_params.enable_text_search,
+            filters=filters,
+        )
+
+        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
+        content = "\n".join(sources_content)
+
+        # Generate a contextual and content specific answer using the search results and chat history
+        response_token_limit = 1024
+        contextual_messages: list[ChatCompletionMessageParam] = build_messages(
+            model=self.chat_model,
+            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
+            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
+            past_messages=chat_params.past_messages,
+            max_tokens=self.chat_token_limit - response_token_limit,
+            fallback_to_default=True,
+        )
+
+        chat_completion_async_stream: AsyncStream[
+            ChatCompletionChunk
+        ] = await self.openai_chat_client.chat.completions.create(
+            # Azure OpenAI takes the deployment name as the model name
+            model=self.chat_deployment if self.chat_deployment else self.chat_model,
+            messages=contextual_messages,
+            temperature=overrides.get("temperature", 0.3),
+            max_tokens=response_token_limit,
+            n=1,
+            stream=True,
+        )
+
+        yield RetrievalResponse(
+            message=Message(content="", role="assistant"),
+            context=RAGContext(
+                data_points={item.id: item.to_dict() for item in results},
+                thoughts=[
+                    ThoughtStep(
+                        title="Prompt to generate search arguments",
+                        description=[str(message) for message in query_messages],
+                        props=(
+                            {"model": self.chat_model, "deployment": self.chat_deployment}
+                            if self.chat_deployment
+                            else {"model": self.chat_model}
+                        ),
+                    ),
+                    ThoughtStep(
+                        title="Search using generated search arguments",
+                        description=query_text,
+                        props={
+                            "top": chat_params.top,
+                            "vector_search": chat_params.enable_vector_search,
+                            "text_search": chat_params.enable_text_search,
+                            "filters": filters,
+                        },
+                    ),
+                    ThoughtStep(
+                        title="Search results",
+                        description=[result.to_dict() for result in results],
+                    ),
+                    ThoughtStep(
+                        title="Prompt to generate answer",
+                        description=[str(message) for message in contextual_messages],
+                        props=(
+                            {"model": self.chat_model, "deployment": self.chat_deployment}
+                            if self.chat_deployment
+                            else {"model": self.chat_model}
+                        ),
+                    ),
+                ],
+            ),
+        )
+
+        async for response_chunk in chat_completion_async_stream:
+            yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant")
+        return
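
Note: the refactor depends on a shared RAGChatBase imported from fastapi_app/rag_simple.py, the second changed file, whose diff is not shown on this page. Below is a minimal sketch of what that base class presumably provides, reconstructed from the code removed above and from the chat_params attributes the new code uses; the ChatParams name and exact field layout are assumptions, not the actual contents of rag_simple.py.

# Hypothetical sketch only: reconstructed from the removed code, not taken from rag_simple.py
import pathlib
from dataclasses import dataclass
from typing import Any

from openai.types.chat import ChatCompletionMessageParam


@dataclass
class ChatParams:
    top: int
    enable_text_search: bool
    enable_vector_search: bool
    original_user_query: str
    past_messages: list[ChatCompletionMessageParam]


class RAGChatBase:
    # Prompt loading moved here from AdvancedRAGChat.__init__
    current_dir = pathlib.Path(__file__).parent
    query_prompt_template = open(current_dir / "prompts/query.txt").read()
    answer_prompt_template = open(current_dir / "prompts/answer.txt").read()

    def get_params(
        self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any]
    ) -> ChatParams:
        # Same parameter parsing that the diff removes from AdvancedRAGChat.run
        original_user_query = messages[-1]["content"]
        if not isinstance(original_user_query, str):
            raise ValueError("The most recent message content must be a string.")
        return ChatParams(
            top=overrides.get("top", 3),
            enable_text_search=overrides.get("retrieval_mode") in ["text", "hybrid", None],
            enable_vector_search=overrides.get("retrieval_mode") in ["vectors", "hybrid", None],
            original_user_query=original_user_query,
            past_messages=messages[:-1],
        )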
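
The new run_stream is an async generator: it first yields a RetrievalResponse carrying the RAG context and thought steps, then yields Message deltas as the answer streams back. The route that consumes it presumably lives in the other changed file; the sketch below shows one plausible way to expose it as newline-delimited JSON with FastAPI's StreamingResponse. The /chat/stream path, the request shape, and the chat_engine wiring are illustrative assumptions, and the serialization assumes the api_models types are Pydantic v2 models.

# Illustrative consumer of AdvancedRAGChat.run_stream; not part of this commit's visible diff
from collections.abc import AsyncGenerator

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()


async def format_as_ndjson(events: AsyncGenerator) -> AsyncGenerator[str, None]:
    # One JSON object per line: the RetrievalResponse first, then each Message delta
    async for event in events:
        yield event.model_dump_json() + "\n"


@app.post("/chat/stream")
async def chat_stream(request: Request):
    body = await request.json()
    chat_engine = request.app.state.chat_engine  # assumed: an AdvancedRAGChat built at startup
    result = chat_engine.run_stream(body["messages"], overrides=body.get("context", {}))
    return StreamingResponse(format_as_ndjson(result), media_type="application/x-ndjson")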
