@@ -1,19 +1,17 @@
-import pathlib
 from collections.abc import AsyncGenerator
-from typing import (
-    Any,
-)
+from typing import Any
 
-from openai import AsyncAzureOpenAI, AsyncOpenAI
-from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
 from openai_messages_token_helper import build_messages, get_token_limit
 
-from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
-from .postgres_searcher import PostgresSearcher
-from .query_rewriter import build_search_function, extract_search_arguments
+from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
+from fastapi_app.postgres_searcher import PostgresSearcher
+from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
+from fastapi_app.rag_simple import RAGChatBase
 
 
-class AdvancedRAGChat:
+class AdvancedRAGChat(RAGChatBase):
     def __init__(
         self,
         *,
@@ -27,29 +25,21 @@ def __init__(
         self.chat_model = chat_model
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
-        current_dir = pathlib.Path(__file__).parent
-        self.query_prompt_template = open(current_dir / "prompts/query.txt").read()
-        self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
 
     async def run(
-        self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {}
-    ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]:
-        text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
-        vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
-        top = overrides.get("top", 3)
-
-        original_user_query = messages[-1]["content"]
-        if not isinstance(original_user_query, str):
-            raise ValueError("The most recent message content must be a string.")
-        past_messages = messages[:-1]
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> RetrievalResponse:
+        chat_params = self.get_params(messages, overrides)
 
         # Generate an optimized keyword search query based on the chat history and the last question
         query_response_token_limit = 500
         query_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=self.query_prompt_template,
-            new_user_content=original_user_query,
-            past_messages=past_messages,
+            new_user_content=chat_params.original_user_query,
+            past_messages=chat_params.past_messages,
             max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
             fallback_to_default=True,
         )
@@ -65,14 +55,14 @@ async def run(
             tool_choice="auto",
         )
 
-        query_text, filters = extract_search_arguments(original_user_query, chat_completion)
+        query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
 
         # Retrieve relevant items from the database with the GPT optimized query
         results = await self.searcher.search_and_embed(
             query_text,
-            top=top,
-            enable_vector_search=vector_search,
-            enable_text_search=text_search,
+            top=chat_params.top,
+            enable_vector_search=chat_params.enable_vector_search,
+            enable_text_search=chat_params.enable_text_search,
             filters=filters,
         )
 
@@ -84,8 +74,8 @@ async def run(
         contextual_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
-            new_user_content=original_user_query + "\n\nSources:\n" + content,
-            past_messages=past_messages,
+            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
+            past_messages=chat_params.past_messages,
             max_tokens=self.chat_token_limit - response_token_limit,
             fallback_to_default=True,
         )
@@ -99,6 +89,7 @@ async def run(
             n=1,
             stream=False,
         )
+
         first_choice_message = chat_completion_response.choices[0].message
 
         return RetrievalResponse(
@@ -119,9 +110,9 @@ async def run(
                     title="Search using generated search arguments",
                     description=query_text,
                     props={
-                        "top": top,
-                        "vector_search": vector_search,
-                        "text_search": text_search,
+                        "top": chat_params.top,
+                        "vector_search": chat_params.enable_vector_search,
+                        "text_search": chat_params.enable_text_search,
                         "filters": filters,
                     },
                 ),
@@ -141,3 +132,114 @@ async def run(
                 ],
             ),
         )
+
+    async def run_stream(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> AsyncGenerator[RetrievalResponse | Message, None]:
+        chat_params = self.get_params(messages, overrides)
+
+        # Generate an optimized keyword search query based on the chat history and the last question
+        query_response_token_limit = 500
+        query_messages: list[ChatCompletionMessageParam] = build_messages(
+            model=self.chat_model,
+            system_prompt=self.query_prompt_template,
+            new_user_content=chat_params.original_user_query,
+            past_messages=chat_params.past_messages,
+            max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
+            fallback_to_default=True,
+        )
+
+        chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
+            messages=query_messages,
+            # Azure OpenAI takes the deployment name as the model name
+            model=self.chat_deployment if self.chat_deployment else self.chat_model,
+            temperature=0.0,  # Minimize creativity for search query generation
+            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, too high risks performance
+            n=1,
+            tools=build_search_function(),
+            tool_choice="auto",
+        )
+
+        query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
+
+        # Retrieve relevant items from the database with the GPT optimized query
+        results = await self.searcher.search_and_embed(
+            query_text,
+            top=chat_params.top,
+            enable_vector_search=chat_params.enable_vector_search,
+            enable_text_search=chat_params.enable_text_search,
+            filters=filters,
+        )
+
+        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
+        content = "\n".join(sources_content)
+
+        # Generate a contextual and content specific answer using the search results and chat history
+        response_token_limit = 1024
+        contextual_messages: list[ChatCompletionMessageParam] = build_messages(
+            model=self.chat_model,
+            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
+            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
+            past_messages=chat_params.past_messages,
+            max_tokens=self.chat_token_limit - response_token_limit,
+            fallback_to_default=True,
+        )
+
+        chat_completion_async_stream: AsyncStream[
+            ChatCompletionChunk
+        ] = await self.openai_chat_client.chat.completions.create(
+            # Azure OpenAI takes the deployment name as the model name
+            model=self.chat_deployment if self.chat_deployment else self.chat_model,
+            messages=contextual_messages,
+            temperature=overrides.get("temperature", 0.3),
+            max_tokens=response_token_limit,
+            n=1,
+            stream=True,
+        )
+
+        yield RetrievalResponse(
+            message=Message(content="", role="assistant"),
+            context=RAGContext(
+                data_points={item.id: item.to_dict() for item in results},
+                thoughts=[
+                    ThoughtStep(
+                        title="Prompt to generate search arguments",
+                        description=[str(message) for message in query_messages],
+                        props=(
+                            {"model": self.chat_model, "deployment": self.chat_deployment}
+                            if self.chat_deployment
+                            else {"model": self.chat_model}
+                        ),
+                    ),
+                    ThoughtStep(
+                        title="Search using generated search arguments",
+                        description=query_text,
+                        props={
+                            "top": chat_params.top,
+                            "vector_search": chat_params.enable_vector_search,
+                            "text_search": chat_params.enable_text_search,
+                            "filters": filters,
+                        },
+                    ),
+                    ThoughtStep(
+                        title="Search results",
+                        description=[result.to_dict() for result in results],
+                    ),
+                    ThoughtStep(
+                        title="Prompt to generate answer",
+                        description=[str(message) for message in contextual_messages],
+                        props=(
+                            {"model": self.chat_model, "deployment": self.chat_deployment}
+                            if self.chat_deployment
+                            else {"model": self.chat_model}
+                        ),
+                    ),
+                ],
+            ),
+        )
+
+        async for response_chunk in chat_completion_async_stream:
+            yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant")
+        return
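
The new methods rely on RAGChatBase from fastapi_app/rag_simple.py, which is not shown in this diff. Below is a minimal sketch of the contract it would need to provide, inferred only from how the subclass uses it above; the ChatParams name and the exact get_params() signature are assumptions, and the field derivation simply mirrors the request parsing that was deleted from run().

# Sketch only -- not the actual fastapi_app.rag_simple implementation.
from dataclasses import dataclass
from typing import Any

from openai.types.chat import ChatCompletionMessageParam


@dataclass
class ChatParams:  # hypothetical name for the object returned by get_params()
    top: int
    enable_text_search: bool
    enable_vector_search: bool
    original_user_query: str
    past_messages: list[ChatCompletionMessageParam]


class RAGChatBase:
    # The subclass reads these templates, so the base class presumably loads prompts/query.txt
    # and prompts/answer.txt (the open() calls removed from AdvancedRAGChat.__init__).
    query_prompt_template: str
    answer_prompt_template: str

    def get_params(
        self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any]
    ) -> ChatParams:
        # Mirrors the logic deleted from run(): derive the search flags and top-k from overrides,
        # validate the latest user message, and split off the preceding chat history.
        original_user_query = messages[-1]["content"]
        if not isinstance(original_user_query, str):
            raise ValueError("The most recent message content must be a string.")
        return ChatParams(
            top=overrides.get("top", 3),
            enable_text_search=overrides.get("retrieval_mode") in ["text", "hybrid", None],
            enable_vector_search=overrides.get("retrieval_mode") in ["vectors", "hybrid", None],
            original_user_query=original_user_query,
            past_messages=messages[:-1],
        )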
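
One way a caller might consume the new run_stream() generator, for example to emit newline-delimited JSON from a FastAPI route. This is illustrative only: the format_as_ndjson helper is hypothetical, and it assumes RetrievalResponse and Message are Pydantic models exposing model_dump(), which this diff does not show.

# Illustrative consumption of run_stream(); the NDJSON framing is an assumption, not part of this change.
import json


async def format_as_ndjson(chat: AdvancedRAGChat, messages, overrides=None):
    # The first event is a RetrievalResponse carrying the RAG context and thought steps
    # (with an empty message); subsequent events are Message chunks with the streamed answer text.
    async for event in chat.run_stream(messages, overrides or {}):
        yield json.dumps(event.model_dump(), default=str) + "\n"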