1
1
from collections .abc import AsyncGenerator
2
2
from typing import Optional , Union
3
3
4
+ from agents import Agent , ModelSettings , OpenAIChatCompletionsModel , Runner , set_tracing_disabled
4
5
from openai import AsyncAzureOpenAI , AsyncOpenAI
5
6
from openai .types .chat import ChatCompletionMessageParam
6
- from pydantic_ai import Agent
7
- from pydantic_ai .models .openai import OpenAIModel
8
- from pydantic_ai .providers .openai import OpenAIProvider
9
- from pydantic_ai .settings import ModelSettings
7
+ from openai .types .responses import ResponseTextDeltaEvent
10
8
11
9
from fastapi_app .api_models import (
12
10
AIChatRoles ,
21
19
from fastapi_app .postgres_searcher import PostgresSearcher
22
20
from fastapi_app .rag_base import RAGChatBase
23
21
22
+ set_tracing_disabled (disabled = True )
23
+
24
24
25
25
class SimpleRAGChat (RAGChatBase ):
26
26
def __init__ (
@@ -38,17 +38,17 @@ def __init__(
38
38
self .model_for_thoughts = (
39
39
{"model" : chat_model , "deployment" : chat_deployment } if chat_deployment else {"model" : chat_model }
40
40
)
41
- pydantic_chat_model = OpenAIModel (
42
- chat_model if chat_deployment is None else chat_deployment ,
43
- provider = OpenAIProvider (openai_client = openai_chat_client ),
41
+ openai_agents_model = OpenAIChatCompletionsModel (
42
+ model = chat_model if chat_deployment is None else chat_deployment , openai_client = openai_chat_client
44
43
)
45
44
self .answer_agent = Agent (
46
- pydantic_chat_model ,
47
- system_prompt = self .answer_prompt_template ,
45
+ name = "Answerer" ,
46
+ instructions = self .answer_prompt_template ,
47
+ model = openai_agents_model ,
48
48
model_settings = ModelSettings (
49
49
temperature = self .chat_params .temperature ,
50
50
max_tokens = self .chat_params .response_token_limit ,
51
- ** ( {"seed" : self .chat_params .seed } if self .chat_params .seed is not None else {}) ,
51
+ extra_body = {"seed" : self .chat_params .seed } if self .chat_params .seed is not None else {},
52
52
),
53
53
)
54
54
@@ -85,19 +85,21 @@ async def answer(
85
85
items : list [ItemPublic ],
86
86
earlier_thoughts : list [ThoughtStep ],
87
87
) -> RetrievalResponse :
88
- response = await self .answer_agent .run (
89
- user_prompt = self .prepare_rag_request (self .chat_params .original_user_query , items ),
90
- message_history = self .chat_params .past_messages ,
88
+ run_results = await Runner .run (
89
+ self .answer_agent ,
90
+ input = self .chat_params .past_messages
91
+ + [{"content" : self .prepare_rag_request (self .chat_params .original_user_query , items ), "role" : "user" }],
91
92
)
93
+
92
94
return RetrievalResponse (
93
- message = Message (content = str (response . output ), role = AIChatRoles .ASSISTANT ),
95
+ message = Message (content = str (run_results . final_output ), role = AIChatRoles .ASSISTANT ),
94
96
context = RAGContext (
95
97
data_points = {item .id : item for item in items },
96
98
thoughts = earlier_thoughts
97
99
+ [
98
100
ThoughtStep (
99
101
title = "Prompt to generate answer" ,
100
- description = response . all_messages () ,
102
+ description = run_results . input ,
101
103
props = self .model_for_thoughts ,
102
104
),
103
105
],
@@ -109,24 +111,27 @@ async def answer_stream(
109
111
items : list [ItemPublic ],
110
112
earlier_thoughts : list [ThoughtStep ],
111
113
) -> AsyncGenerator [RetrievalResponseDelta , None ]:
112
- async with self .answer_agent .run_stream (
113
- self .prepare_rag_request (self .chat_params .original_user_query , items ),
114
- message_history = self .chat_params .past_messages ,
115
- ) as agent_stream_runner :
116
- yield RetrievalResponseDelta (
117
- context = RAGContext (
118
- data_points = {item .id : item for item in items },
119
- thoughts = earlier_thoughts
120
- + [
121
- ThoughtStep (
122
- title = "Prompt to generate answer" ,
123
- description = agent_stream_runner .all_messages (),
124
- props = self .model_for_thoughts ,
125
- ),
126
- ],
127
- ),
128
- )
114
+ run_results = Runner .run_streamed (
115
+ self .answer_agent ,
116
+ input = self .chat_params .past_messages
117
+ + [{"content" : self .prepare_rag_request (self .chat_params .original_user_query , items ), "role" : "user" }],
118
+ )
119
+
120
+ yield RetrievalResponseDelta (
121
+ context = RAGContext (
122
+ data_points = {item .id : item for item in items },
123
+ thoughts = earlier_thoughts
124
+ + [
125
+ ThoughtStep (
126
+ title = "Prompt to generate answer" ,
127
+ description = run_results .input ,
128
+ props = self .model_for_thoughts ,
129
+ ),
130
+ ],
131
+ ),
132
+ )
129
133
130
- async for message in agent_stream_runner .stream_text (delta = True , debounce_by = None ):
131
- yield RetrievalResponseDelta (delta = Message (content = str (message ), role = AIChatRoles .ASSISTANT ))
132
- return
134
+ async for event in run_results .stream_events ():
135
+ if event .type == "raw_response_event" and isinstance (event .data , ResponseTextDeltaEvent ):
136
+ yield RetrievalResponseDelta (delta = Message (content = str (event .data .delta ), role = AIChatRoles .ASSISTANT ))
137
+ return
0 commit comments