@@ -1,13 +1,18 @@
+import json
 from collections.abc import AsyncGenerator
 from typing import Optional, Union
 
+from agents import Agent, ModelSettings, OpenAIChatCompletionsModel, Runner, function_tool, set_tracing_disabled
 from openai import AsyncAzureOpenAI, AsyncOpenAI
-from openai.types.chat import ChatCompletionMessageParam
-from pydantic_ai import Agent, RunContext
-from pydantic_ai.messages import ModelMessagesTypeAdapter
-from pydantic_ai.models.openai import OpenAIModel
-from pydantic_ai.providers.openai import OpenAIProvider
-from pydantic_ai.settings import ModelSettings
+from openai.types.chat import (
+    ChatCompletionMessageParam,
+)
+from openai.types.responses import (
+    EasyInputMessageParam,
+    ResponseFunctionToolCallParam,
+    ResponseTextDeltaEvent,
+)
+from openai.types.responses.response_input_item_param import FunctionCallOutput
 
 from fastapi_app.api_models import (
     AIChatRoles,
@@ -24,7 +29,9 @@
     ThoughtStep,
 )
 from fastapi_app.postgres_searcher import PostgresSearcher
-from fastapi_app.rag_base import ChatParams, RAGChatBase
+from fastapi_app.rag_base import RAGChatBase
+
+set_tracing_disabled(disabled=True)
 
 
 class AdvancedRAGChat(RAGChatBase):
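One note on the new module-level call: `set_tracing_disabled(disabled=True)` switches off the Agents SDK's trace export for every run, presumably because the exporter expects an OpenAI platform key that an Azure OpenAI deployment may not have. The SDK also accepts this per run; a minimal sketch, assuming the `openai-agents` package and a configured `OPENAI_API_KEY`:

```python
import asyncio

from agents import Agent, RunConfig, Runner

async def main() -> None:
    agent = Agent(name="Example", instructions="Reply briefly.")
    # RunConfig(tracing_disabled=True) is the per-run equivalent of the
    # module-level set_tracing_disabled(disabled=True) used in this diff.
    result = await Runner.run(agent, input="hi", run_config=RunConfig(tracing_disabled=True))
    print(result.final_output)

asyncio.run(main())
```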
@@ -46,34 +53,29 @@ def __init__(
         self.model_for_thoughts = (
             {"model": chat_model, "deployment": chat_deployment} if chat_deployment else {"model": chat_model}
         )
-        pydantic_chat_model = OpenAIModel(
-            chat_model if chat_deployment is None else chat_deployment,
-            provider=OpenAIProvider(openai_client=openai_chat_client),
+        openai_agents_model = OpenAIChatCompletionsModel(
+            model=chat_model if chat_deployment is None else chat_deployment, openai_client=openai_chat_client
         )
-        self.search_agent = Agent[ChatParams, SearchResults](
-            pydantic_chat_model,
-            model_settings=ModelSettings(
-                temperature=0.0,
-                max_tokens=500,
-                **({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
-            ),
-            system_prompt=self.query_prompt_template,
-            tools=[self.search_database],
-            output_type=SearchResults,
+        self.search_agent = Agent(
+            name="Searcher",
+            instructions=self.query_prompt_template,
+            tools=[function_tool(self.search_database)],
+            tool_use_behavior="stop_on_first_tool",
+            model=openai_agents_model,
         )
         self.answer_agent = Agent(
-            pydantic_chat_model,
-            system_prompt=self.answer_prompt_template,
+            name="Answerer",
+            instructions=self.answer_prompt_template,
+            model=openai_agents_model,
             model_settings=ModelSettings(
                 temperature=self.chat_params.temperature,
                 max_tokens=self.chat_params.response_token_limit,
-                **({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
+                extra_body={"seed": self.chat_params.seed} if self.chat_params.seed is not None else {},
             ),
         )
 
     async def search_database(
         self,
-        ctx: RunContext[ChatParams],
         search_query: str,
         price_filter: Optional[PriceFilter] = None,
         brand_filter: Optional[BrandFilter] = None,
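Two behaviors in this hunk are worth spelling out. With `tool_use_behavior="stop_on_first_tool"`, the Searcher run ends as soon as its first tool returns, so the `SearchResults` produced by `search_database` becomes the run's output with no extra model turn; and `seed` moves into `extra_body` because the SDK's `ModelSettings` passes that dict straight through to the underlying chat-completions request. A self-contained sketch of the first behavior, assuming the `openai-agents` package and a configured `OPENAI_API_KEY` (the `lookup` tool is a hypothetical stand-in for `search_database`):

```python
import asyncio

from agents import Agent, Runner, function_tool

@function_tool
def lookup(query: str) -> str:
    # Hypothetical stand-in for search_database.
    return f"results for {query!r}"

agent = Agent(
    name="Searcher",
    instructions="Call the lookup tool for every question.",
    tools=[lookup],
    tool_use_behavior="stop_on_first_tool",
)

async def main() -> None:
    result = await Runner.run(agent, input="red shoes under $50")
    # The run halts after the first tool call; final_output is
    # lookup()'s return value, not a model-written summary of it.
    print(result.final_output)

asyncio.run(main())
```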
@@ -97,66 +99,88 @@ async def search_database(
             filters.append(brand_filter)
         results = await self.searcher.search_and_embed(
             search_query,
-            top=ctx.deps.top,
-            enable_vector_search=ctx.deps.enable_vector_search,
-            enable_text_search=ctx.deps.enable_text_search,
+            top=self.chat_params.top,
+            enable_vector_search=self.chat_params.enable_vector_search,
+            enable_text_search=self.chat_params.enable_text_search,
             filters=filters,
         )
         return SearchResults(
             query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
         )
 
     async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
-        few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
+        few_shots = json.loads(self.query_fewshots)
+        few_shot_inputs = []
+        for few_shot in few_shots:
+            if few_shot["role"] == "user":
+                message = EasyInputMessageParam(role="user", content=few_shot["content"])
+            elif few_shot["role"] == "assistant" and few_shot["tool_calls"] is not None:
+                message = ResponseFunctionToolCallParam(
+                    id="madeup",
+                    call_id=few_shot["tool_calls"][0]["id"],
+                    name=few_shot["tool_calls"][0]["function"]["name"],
+                    arguments=few_shot["tool_calls"][0]["function"]["arguments"],
+                    type="function_call",
+                )
+            elif few_shot["role"] == "tool" and few_shot["tool_call_id"] is not None:
+                message = FunctionCallOutput(
+                    id="madeupoutput",
+                    call_id=few_shot["tool_call_id"],
+                    output=few_shot["content"],
+                    type="function_call_output",
+                )
+            few_shot_inputs.append(message)
+
         user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
-        results = await self.search_agent.run(
-            user_query,
-            message_history=few_shots + self.chat_params.past_messages,
-            deps=self.chat_params,
-        )
-        items = results.output.items
+        new_user_message = EasyInputMessageParam(role="user", content=user_query)
+        all_messages = few_shot_inputs + self.chat_params.past_messages + [new_user_message]
+
+        run_results = await Runner.run(self.search_agent, input=all_messages)
+        search_results = run_results.new_items[-1].output
+
         thoughts = [
             ThoughtStep(
                 title="Prompt to generate search arguments",
-                description=results.all_messages(),
+                description=run_results.input,
                 props=self.model_for_thoughts,
             ),
             ThoughtStep(
                 title="Search using generated search arguments",
-                description=results.output.query,
+                description=search_results.query,
                 props={
                     "top": self.chat_params.top,
                     "vector_search": self.chat_params.enable_vector_search,
                     "text_search": self.chat_params.enable_text_search,
-                    "filters": results.output.filters,
+                    "filters": search_results.filters,
                 },
             ),
             ThoughtStep(
                 title="Search results",
-                description=items,
+                description=search_results.items,
             ),
         ]
-        return items, thoughts
+        return search_results.items, thoughts
 
     async def answer(
         self,
         items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> RetrievalResponse:
-        response = await self.answer_agent.run(
-            user_prompt=self.prepare_rag_request(self.chat_params.original_user_query, items),
-            message_history=self.chat_params.past_messages,
+        run_results = await Runner.run(
+            self.answer_agent,
+            input=self.chat_params.past_messages
+            + [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],
         )
 
         return RetrievalResponse(
-            message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
+            message=Message(content=str(run_results.final_output), role=AIChatRoles.ASSISTANT),
             context=RAGContext(
                 data_points={item.id: item for item in items},
                 thoughts=earlier_thoughts
                 + [
                     ThoughtStep(
                         title="Prompt to generate answer",
-                        description=response.all_messages(),
+                        description=run_results.input,
                         props=self.model_for_thoughts,
                     ),
                 ],
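The few-shot conversion at the top of this hunk assumes `self.query_fewshots` is a JSON string of Chat Completions-style messages (every assistant entry carrying `tool_calls`, every tool entry a `tool_call_id`) and rewrites them as Responses-API input items; the synthetic `id="madeup"` values only satisfy the typed dicts, while `call_id` is what pairs a call with its output. A sketch of the expected shape, with a made-up query and call id:

```python
import json

# Hypothetical few-shot payload in the shape the loop expects.
few_shots_json = r"""
[
    {"role": "user", "content": "any climbing gear?"},
    {"role": "assistant", "content": null,
     "tool_calls": [{"id": "call_abc123", "type": "function",
                     "function": {"name": "search_database",
                                  "arguments": "{\"search_query\": \"climbing gear\"}"}}]},
    {"role": "tool", "tool_call_id": "call_abc123", "content": "Search results ..."}
]
"""
few_shots = json.loads(few_shots_json)

# The loop would emit three input items, paired by call_id:
#   EasyInputMessageParam        -> {"role": "user", "content": "any climbing gear?"}
#   ResponseFunctionToolCallParam -> {"type": "function_call", "call_id": "call_abc123", ...}
#   FunctionCallOutput            -> {"type": "function_call_output", "call_id": "call_abc123", ...}
print(len(few_shots))  # 3
```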
@@ -168,24 +192,27 @@ async def answer_stream(
         items: list[ItemPublic],
         earlier_thoughts: list[ThoughtStep],
     ) -> AsyncGenerator[RetrievalResponseDelta, None]:
-        async with self.answer_agent.run_stream(
-            self.prepare_rag_request(self.chat_params.original_user_query, items),
-            message_history=self.chat_params.past_messages,
-        ) as agent_stream_runner:
-            yield RetrievalResponseDelta(
-                context=RAGContext(
-                    data_points={item.id: item for item in items},
-                    thoughts=earlier_thoughts
-                    + [
-                        ThoughtStep(
-                            title="Prompt to generate answer",
-                            description=agent_stream_runner.all_messages(),
-                            props=self.model_for_thoughts,
-                        ),
-                    ],
-                ),
-            )
-
-            async for message in agent_stream_runner.stream_text(delta=True, debounce_by=None):
-                yield RetrievalResponseDelta(delta=Message(content=str(message), role=AIChatRoles.ASSISTANT))
-            return
+        run_results = Runner.run_streamed(
+            self.answer_agent,
+            input=self.chat_params.past_messages
+            + [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],
+        )
+
+        yield RetrievalResponseDelta(
+            context=RAGContext(
+                data_points={item.id: item for item in items},
+                thoughts=earlier_thoughts
+                + [
+                    ThoughtStep(
+                        title="Prompt to generate answer",
+                        description=run_results.input,
+                        props=self.model_for_thoughts,
+                    ),
+                ],
+            ),
+        )
+
+        async for event in run_results.stream_events():
+            if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
+                yield RetrievalResponseDelta(delta=Message(content=str(event.data.delta), role=AIChatRoles.ASSISTANT))
+        return
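The streaming rewrite replaces pydantic-ai's `stream_text(delta=True)` with the Agents SDK event stream, keeping only the raw `ResponseTextDeltaEvent` chunks and ignoring the other event types (tool calls, run items, agent switches). A minimal standalone version of the same consumption pattern, assuming the `openai-agents` package and a configured `OPENAI_API_KEY`:

```python
import asyncio

from agents import Agent, Runner
from openai.types.responses import ResponseTextDeltaEvent

async def main() -> None:
    agent = Agent(name="Answerer", instructions="Answer briefly.")
    run_results = Runner.run_streamed(agent, input="Say hello.")
    async for event in run_results.stream_events():
        # raw_response_event wraps the underlying Responses API stream;
        # ResponseTextDeltaEvent carries each incremental text chunk.
        if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
            print(event.data.delta, end="", flush=True)

asyncio.run(main())
```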