12
12
13
13
from fastapi_app .api_models import (
14
14
AIChatRoles ,
15
+ ItemPublic ,
15
16
Message ,
16
17
RAGContext ,
17
18
RetrievalResponse ,
@@ -50,6 +51,14 @@ class BrandFilter(TypedDict):
50
51
"""The brand name to compare against (e.g., 'AirStrider')"""
51
52
52
53
54
class SearchResults(TypedDict):
    """Structured output of the product-search tool: the matched items
    together with the filters that were applied to produce them."""

    items: list[ItemPublic]
    """List of items that match the search query and filters"""

    filters: list[Union[PriceFilter, BrandFilter]]
    """List of filters applied to the search results"""
53
62
class AdvancedRAGChat (RAGChatBase ):
54
63
def __init__ (
55
64
self ,
@@ -71,7 +80,7 @@ async def search_database(
71
80
search_query : str ,
72
81
price_filter : Optional [PriceFilter ] = None ,
73
82
brand_filter : Optional [BrandFilter ] = None ,
74
- ) -> list [ str ] :
83
+ ) -> SearchResults :
75
84
"""
76
85
Search PostgreSQL database for relevant products based on user query
77
86
@@ -83,7 +92,6 @@ async def search_database(
83
92
Returns:
84
93
List of formatted items that match the search query and filters
85
94
"""
86
- print (search_query , price_filter , brand_filter )
87
95
# Only send non-None filters
88
96
filters = []
89
97
if price_filter :
@@ -97,9 +105,9 @@ async def search_database(
97
105
enable_text_search = ctx .deps .enable_text_search ,
98
106
filters = filters ,
99
107
)
100
- return [ f"[ { (item .id ) } ]: { item . to_str_for_rag () } \n \n " for item in results ]
108
+ return SearchResults ( items = [ ItemPublic . model_validate (item .to_dict ()) for item in results ], filters = filters )
101
109
102
- async def prepare_context (self , chat_params : ChatParams ) -> tuple [str , list [Item ], list [ThoughtStep ]]:
110
+ async def prepare_context (self , chat_params : ChatParams ) -> tuple [list [ItemPublic ], list [ThoughtStep ]]:
103
111
model = OpenAIModel (
104
112
os .environ ["AZURE_OPENAI_CHAT_DEPLOYMENT" ], provider = OpenAIProvider (openai_client = self .openai_chat_client )
105
113
)
@@ -108,17 +116,15 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[str, list[Item
108
116
model_settings = ModelSettings (temperature = 0.0 , max_tokens = 500 , seed = chat_params .seed ),
109
117
system_prompt = self .query_prompt_template ,
110
118
tools = [self .search_database ],
111
- output_type = list [ str ] ,
119
+ output_type = SearchResults ,
112
120
)
113
121
# TODO: Provide few-shot examples
114
122
results = await agent .run (
115
123
f"Find search results for user query: { chat_params .original_user_query } " ,
116
124
# message_history=chat_params.past_messages, # TODO
117
125
deps = chat_params ,
118
126
)
119
- if not isinstance (results , list ):
120
- raise ValueError ("Search results should be a list of strings" )
121
-
127
+ items = results .output .items
122
128
thoughts = [
123
129
ThoughtStep (
124
130
title = "Prompt to generate search arguments" ,
@@ -144,12 +150,12 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[str, list[Item
144
150
description = "" , # TODO
145
151
),
146
152
]
147
- return results , thoughts
153
+ return items , thoughts
148
154
149
155
async def answer (
150
156
self ,
151
157
chat_params : ChatParams ,
152
- results : list [str ],
158
+ items : list [ItemPublic ],
153
159
earlier_thoughts : list [ThoughtStep ],
154
160
) -> RetrievalResponse :
155
161
agent = Agent (
@@ -163,15 +169,16 @@ async def answer(
163
169
),
164
170
)
165
171
172
+ item_references = [item .to_str_for_rag () for item in items ]
166
173
response = await agent .run (
167
- user_prompt = chat_params .original_user_query + "Sources:\n " + "\n " .join (results ),
174
+ user_prompt = chat_params .original_user_query + "Sources:\n " + "\n " .join (item_references ),
168
175
message_history = chat_params .past_messages ,
169
176
)
170
177
171
178
return RetrievalResponse (
172
179
message = Message (content = str (response .output ), role = AIChatRoles .ASSISTANT ),
173
180
context = RAGContext (
174
- data_points = {item . id : item . to_dict () for item in results },
181
+ data_points = {}, # TODO
175
182
thoughts = earlier_thoughts
176
183
+ [
177
184
ThoughtStep (
0 commit comments