 from openai_messages_token_helper import build_messages, get_token_limit
 
 from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
+from fastapi_app.postgres_models import Item
 from fastapi_app.postgres_searcher import PostgresSearcher
 from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
-from fastapi_app.rag_simple import RAGChatBase
+from fastapi_app.rag_simple import ChatParams, RAGChatBase
 
 
 class AdvancedRAGChat(RAGChatBase):
@@ -26,15 +27,10 @@ def __init__(
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
 
-    async def run(
-        self,
-        messages: list[ChatCompletionMessageParam],
-        overrides: dict[str, Any] = {},
-    ) -> RetrievalResponse:
-        chat_params = self.get_params(messages, overrides)
-
-        # Generate an optimized keyword search query based on the chat history and the last question
-        query_response_token_limit = 500
+    async def generate_search_query(
+        self, chat_params: ChatParams, query_response_token_limit: int
+    ) -> tuple[list[ChatCompletionMessageParam], Any | str | None, list]:
+        """Generate an optimized keyword search query based on the chat history and the last question"""
         query_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=self.query_prompt_template,
@@ -57,6 +53,12 @@ async def run(
 
         query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
 
+        return query_messages, query_text, filters
+
+    async def retreive_and_build_context(
+        self, chat_params: ChatParams, query_text: str | Any | None, filters: list
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item]]:
+        """Retrieve relevant items from the database and build a context for the chat model."""
         # Retrieve relevant items from the database with the GPT optimized query
         results = await self.searcher.search_and_embed(
             query_text,
@@ -70,22 +72,40 @@ async def run(
         content = "\n".join(sources_content)
 
         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
         contextual_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
+            system_prompt=chat_params.prompt_template,
             new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
             past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
+            max_tokens=self.chat_token_limit - chat_params.response_token_limit,
             fallback_to_default=True,
         )
+        return contextual_messages, results
+
+    async def run(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> RetrievalResponse:
+        chat_params = self.get_params(messages, overrides)
+
+        # Generate an optimized keyword search query based on the chat history and the last question
+        query_messages, query_text, filters = await self.generate_search_query(
+            chat_params=chat_params, query_response_token_limit=500
+        )
+
+        # Retrieve relevant items from the database with the GPT optimized query
+        # Generate a contextual and content specific answer using the search results and chat history
+        contextual_messages, results = await self.retreive_and_build_context(
+            chat_params=chat_params, query_text=query_text, filters=filters
+        )
 
         chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
-            temperature=overrides.get("temperature", 0.3),
-            max_tokens=response_token_limit,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=False,
         )
@@ -141,50 +161,14 @@ async def run_stream(
         chat_params = self.get_params(messages, overrides)
 
         # Generate an optimized keyword search query based on the chat history and the last question
-        query_response_token_limit = 500
-        query_messages: list[ChatCompletionMessageParam] = build_messages(
-            model=self.chat_model,
-            system_prompt=self.query_prompt_template,
-            new_user_content=chat_params.original_user_query,
-            past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
-            fallback_to_default=True,
+        query_messages, query_text, filters = await self.generate_search_query(
+            chat_params=chat_params, query_response_token_limit=500
         )
 
-        chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
-            messages=query_messages,
-            # Azure OpenAI takes the deployment name as the model name
-            model=self.chat_deployment if self.chat_deployment else self.chat_model,
-            temperature=0.0,  # Minimize creativity for search query generation
-            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, too high risks performance
-            n=1,
-            tools=build_search_function(),
-            tool_choice="auto",
-        )
-
-        query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
-
         # Retrieve relevant items from the database with the GPT optimized query
-        results = await self.searcher.search_and_embed(
-            query_text,
-            top=chat_params.top,
-            enable_vector_search=chat_params.enable_vector_search,
-            enable_text_search=chat_params.enable_text_search,
-            filters=filters,
-        )
-
-        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
-        content = "\n".join(sources_content)
-
         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
-        contextual_messages: list[ChatCompletionMessageParam] = build_messages(
-            model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
-            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
-            past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
-            fallback_to_default=True,
+        contextual_messages, results = await self.retreive_and_build_context(
+            chat_params=chat_params, query_text=query_text, filters=filters
         )
 
         chat_completion_async_stream: AsyncStream[
@@ -193,8 +177,8 @@ async def run_stream(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
-            temperature=overrides.get("temperature", 0.3),
-            max_tokens=response_token_limit,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=True,
         )
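
The extracted helpers lean entirely on the ChatParams object now imported from fastapi_app.rag_simple. The diff never shows that class, so the following is only a minimal sketch of what it plausibly carries, inferred from the attributes the refactored code reads (original_user_query, past_messages, prompt_template, temperature, response_token_limit, top, enable_vector_search, enable_text_search); the real definition in rag_simple.py may differ.

# Hypothetical sketch only -- inferred from attribute accesses in the diff;
# the actual ChatParams in fastapi_app/rag_simple.py may differ.
from dataclasses import dataclass, field

from openai.types.chat import ChatCompletionMessageParam


@dataclass
class ChatParams:
    original_user_query: str          # last user message, fed to the query rewriter
    past_messages: list[ChatCompletionMessageParam] = field(default_factory=list)
    prompt_template: str = ""         # answer-step system prompt (was overrides.get("prompt_template"))
    temperature: float = 0.3          # old default from overrides.get("temperature", 0.3)
    response_token_limit: int = 1024  # replaces the hardcoded 1024 in both run paths
    top: int = 3                      # row count for search_and_embed (default value assumed)
    enable_vector_search: bool = True
    enable_text_search: bool = True

Centralizing these values in get_params() is what lets run() and run_stream() share generate_search_query() and retreive_and_build_context() instead of duplicating the query-rewrite and retrieval blocks.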
0 commit comments
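For completeness, a hedged usage sketch of the refactored entry point. The constructor arguments mirror the attributes visible in the diff (searcher, openai_chat_client, chat_model, chat_deployment), but the exact __init__ signature sits outside these hunks, so treat the parameter names and values as assumptions; the override keys shown are the ones the pre-refactor code read.

# Hypothetical caller; only run()'s signature and the "temperature" /
# "prompt_template" override keys are confirmed by the diff itself.
chat = AdvancedRAGChat(
    searcher=searcher,                 # a PostgresSearcher (assumed parameter name)
    openai_chat_client=openai_client,  # an async OpenAI/Azure OpenAI client (assumed)
    chat_model="gpt-4o-mini",          # assumed model; on Azure the deployment name is passed as the model
    chat_deployment=None,
)
response = await chat.run(
    messages=[{"role": "user", "content": "What can you tell me about these products?"}],
    overrides={"temperature": 0.3},
)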