Commit 7e11373

refactor code
1 parent 34a7afe commit 7e11373

2 files changed (+86, -92 lines)

src/backend/fastapi_app/rag_advanced.py

Lines changed: 41 additions & 57 deletions
@@ -6,9 +6,10 @@
 from openai_messages_token_helper import build_messages, get_token_limit

 from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
+from fastapi_app.postgres_models import Item
 from fastapi_app.postgres_searcher import PostgresSearcher
 from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
-from fastapi_app.rag_simple import RAGChatBase
+from fastapi_app.rag_simple import ChatParams, RAGChatBase


 class AdvancedRAGChat(RAGChatBase):
@@ -26,15 +27,10 @@ def __init__(
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)

-    async def run(
-        self,
-        messages: list[ChatCompletionMessageParam],
-        overrides: dict[str, Any] = {},
-    ) -> RetrievalResponse:
-        chat_params = self.get_params(messages, overrides)
-
-        # Generate an optimized keyword search query based on the chat history and the last question
-        query_response_token_limit = 500
+    async def generate_search_query(
+        self, chat_params: ChatParams, query_response_token_limit: int
+    ) -> tuple[list[ChatCompletionMessageParam], Any | str | None, list]:
+        """Generate an optimized keyword search query based on the chat history and the last question"""
         query_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=self.query_prompt_template,
@@ -57,6 +53,12 @@ async def run(

         query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)

+        return query_messages, query_text, filters
+
+    async def retreive_and_build_context(
+        self, chat_params: ChatParams, query_text: str | Any | None, filters: list
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item]]:
+        """Retrieve relevant items from the database and build a context for the chat model."""
         # Retrieve relevant items from the database with the GPT optimized query
         results = await self.searcher.search_and_embed(
             query_text,
@@ -70,22 +72,40 @@ async def run(
         content = "\n".join(sources_content)

         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
         contextual_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
+            system_prompt=chat_params.prompt_template,
             new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
             past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
+            max_tokens=self.chat_token_limit - chat_params.response_token_limit,
             fallback_to_default=True,
         )
+        return contextual_messages, results
+
+    async def run(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> RetrievalResponse:
+        chat_params = self.get_params(messages, overrides)
+
+        # Generate an optimized keyword search query based on the chat history and the last question
+        query_messages, query_text, filters = await self.generate_search_query(
+            chat_params=chat_params, query_response_token_limit=500
+        )
+
+        # Retrieve relevant items from the database with the GPT optimized query
+        # Generate a contextual and content specific answer using the search results and chat history
+        contextual_messages, results = await self.retreive_and_build_context(
+            chat_params=chat_params, query_text=query_text, filters=filters
+        )

         chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
-            temperature=overrides.get("temperature", 0.3),
-            max_tokens=response_token_limit,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=False,
         )
@@ -141,50 +161,14 @@ async def run_stream(
         chat_params = self.get_params(messages, overrides)

         # Generate an optimized keyword search query based on the chat history and the last question
-        query_response_token_limit = 500
-        query_messages: list[ChatCompletionMessageParam] = build_messages(
-            model=self.chat_model,
-            system_prompt=self.query_prompt_template,
-            new_user_content=chat_params.original_user_query,
-            past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
-            fallback_to_default=True,
+        query_messages, query_text, filters = await self.generate_search_query(
+            chat_params=chat_params, query_response_token_limit=500
         )

-        chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
-            messages=query_messages,
-            # Azure OpenAI takes the deployment name as the model name
-            model=self.chat_deployment if self.chat_deployment else self.chat_model,
-            temperature=0.0,  # Minimize creativity for search query generation
-            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, too high risks performance
-            n=1,
-            tools=build_search_function(),
-            tool_choice="auto",
-        )
-
-        query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
-
         # Retrieve relevant items from the database with the GPT optimized query
-        results = await self.searcher.search_and_embed(
-            query_text,
-            top=chat_params.top,
-            enable_vector_search=chat_params.enable_vector_search,
-            enable_text_search=chat_params.enable_text_search,
-            filters=filters,
-        )
-
-        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
-        content = "\n".join(sources_content)
-
         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
-        contextual_messages: list[ChatCompletionMessageParam] = build_messages(
-            model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
-            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
-            past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
-            fallback_to_default=True,
+        contextual_messages, results = await self.retreive_and_build_context(
+            chat_params=chat_params, query_text=query_text, filters=filters
         )

         chat_completion_async_stream: AsyncStream[
@@ -193,8 +177,8 @@ async def run_stream(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
-            temperature=overrides.get("temperature", 0.3),
-            max_tokens=response_token_limit,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=True,
         )
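
Taken together, the rag_advanced.py changes mean `run` and `run_stream` now share the query-rewrite and retrieval steps instead of duplicating them; the two entry points differ only in the final completions call (stream=False vs. stream=True). Below is a minimal sketch of how the extracted helpers compose. The `answer` wrapper and the bare `chat` argument are hypothetical stand-ins for a fully constructed AdvancedRAGChat, not part of the commit:

```python
from typing import Any

from openai.types.chat import ChatCompletionMessageParam


# Hypothetical driver, for illustration only: `chat` is assumed to be a
# fully configured AdvancedRAGChat instance.
async def answer(chat, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any]):
    chat_params = chat.get_params(messages, overrides)

    # Step 1: rewrite the chat history into an optimized keyword search query.
    query_messages, query_text, filters = await chat.generate_search_query(
        chat_params=chat_params, query_response_token_limit=500
    )

    # Step 2: run the search and build the source-grounded prompt.
    contextual_messages, results = await chat.retreive_and_build_context(
        chat_params=chat_params, query_text=query_text, filters=filters
    )
    return contextual_messages, results
```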

src/backend/fastapi_app/rag_simple.py

Lines changed: 45 additions & 35 deletions
@@ -9,16 +9,19 @@
 from pydantic import BaseModel

 from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
+from fastapi_app.postgres_models import Item
 from fastapi_app.postgres_searcher import PostgresSearcher


 class ChatParams(BaseModel):
-    top: int
-    temperature: float
+    top: int = 3
+    temperature: float = 0.3
+    response_token_limit: int = 1024
     enable_text_search: bool
     enable_vector_search: bool
     original_user_query: str
     past_messages: list[ChatCompletionMessageParam]
+    prompt_template: str


 class RAGChatBase(ABC):
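
With defaults on `top`, `temperature`, and `response_token_limit`, only the fields without defaults must be supplied when constructing `ChatParams`. A self-contained sketch of that pydantic behavior, using a trimmed stand-in model (only `prompt_template` of the required fields is shown) rather than the project's actual class:

```python
from pydantic import BaseModel


# Stand-in for the ChatParams pattern above: fields with defaults may be
# omitted at construction time; fields without defaults remain required.
class ChatParamsSketch(BaseModel):
    top: int = 3
    temperature: float = 0.3
    response_token_limit: int = 1024
    prompt_template: str  # required: resolved from overrides or the answer prompt


params = ChatParamsSketch(prompt_template="Answer using only the provided sources.")
assert params.top == 3 and params.temperature == 0.3 and params.response_token_limit == 1024

# Explicit values still win over the defaults:
params = ChatParamsSketch(prompt_template="...", temperature=0.0)
assert params.temperature == 0.0
```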
@@ -27,17 +30,24 @@ class RAGChatBase(ABC):
     answer_prompt_template = open(current_dir / "prompts/answer.txt").read()

     def get_params(self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any]) -> ChatParams:
-        enable_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
-        enable_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
         top: int = overrides.get("top", 3)
         temperature: float = overrides.get("temperature", 0.3)
+        response_token_limit = 1024
+        prompt_template = overrides.get("prompt_template") or self.answer_prompt_template
+
+        enable_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
+        enable_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
+
         original_user_query = messages[-1]["content"]
         if not isinstance(original_user_query, str):
             raise ValueError("The most recent message content must be a string.")
         past_messages = messages[:-1]
+
         return ChatParams(
             top=top,
             temperature=temperature,
+            response_token_limit=response_token_limit,
+            prompt_template=prompt_template,
             enable_text_search=enable_text_search,
             enable_vector_search=enable_vector_search,
             original_user_query=original_user_query,
@@ -52,6 +62,15 @@ async def run(
     ) -> RetrievalResponse:
         raise NotImplementedError

+    @abstractmethod
+    async def retreive_and_build_context(
+        self,
+        chat_params: ChatParams,
+        *args,
+        **kwargs,
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item]]:
+        raise NotImplementedError
+
     @abstractmethod
     async def run_stream(
         self,
@@ -78,12 +97,10 @@ def __init__(
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)

-    async def run(
-        self,
-        messages: list[ChatCompletionMessageParam],
-        overrides: dict[str, Any] = {},
-    ) -> RetrievalResponse:
-        chat_params = self.get_params(messages, overrides)
+    async def retreive_and_build_context(
+        self, chat_params: ChatParams
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item]]:
+        """Retrieve relevant items from the database and build a context for the chat model."""

         # Retrieve relevant items from the database
         results = await self.searcher.search_and_embed(
@@ -97,22 +114,33 @@ async def run(
         content = "\n".join(sources_content)

         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
         contextual_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
+            system_prompt=chat_params.prompt_template,
             new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
             past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
+            max_tokens=self.chat_token_limit - chat_params.response_token_limit,
             fallback_to_default=True,
         )
+        return contextual_messages, results
+
+    async def run(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> RetrievalResponse:
+        chat_params = self.get_params(messages, overrides)
+
+        # Retrieve relevant items from the database
+        # Generate a contextual and content specific answer using the search results and chat history
+        contextual_messages, results = await self.retreive_and_build_context(chat_params=chat_params)

         chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
             temperature=chat_params.temperature,
-            max_tokens=response_token_limit,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=False,
         )
@@ -158,35 +186,17 @@ async def run_stream(
         chat_params = self.get_params(messages, overrides)

         # Retrieve relevant items from the database
-        results = await self.searcher.search_and_embed(
-            chat_params.original_user_query,
-            top=chat_params.top,
-            enable_vector_search=chat_params.enable_vector_search,
-            enable_text_search=chat_params.enable_text_search,
-        )
-
-        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
-        content = "\n".join(sources_content)
-
         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
-        contextual_messages: list[ChatCompletionMessageParam] = build_messages(
-            model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
-            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
-            past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
-            fallback_to_default=True,
-        )
+        contextual_messages, results = await self.retreive_and_build_context(chat_params=chat_params)

         chat_completion_async_stream: AsyncStream[
             ChatCompletionChunk
         ] = await self.openai_chat_client.chat.completions.create(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
-            temperature=overrides.get("temperature", 0.3),
-            max_tokens=response_token_limit,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=True,
         )
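
The new abstract method is declared with `*args, **kwargs` so each subclass can narrow its signature: `SimpleRAGChat.retreive_and_build_context` takes only `chat_params`, while the advanced version also accepts `query_text` and `filters` (the method name's spelling follows the commit). A self-contained sketch of that contract with toy classes, not the project's modules:

```python
import asyncio
from abc import ABC, abstractmethod


class Base(ABC):
    # The permissive *args/**kwargs signature lets subclasses declare only
    # the parameters they actually need.
    @abstractmethod
    async def retreive_and_build_context(self, chat_params, *args, **kwargs):
        raise NotImplementedError


class Simple(Base):
    async def retreive_and_build_context(self, chat_params):
        return ["<messages built from a plain search>"], []


class Advanced(Base):
    async def retreive_and_build_context(self, chat_params, query_text, filters):
        return [f"<messages built from query {query_text!r}>"], []


async def main():
    print(await Simple().retreive_and_build_context({}))
    print(await Advanced().retreive_and_build_context({}, "climbing shoes", []))


asyncio.run(main())
```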
