Skip to content

Commit 42e7795

Browse files
committed
Port to OpenAI-agents SDK
1 parent b000a71 commit 42e7795

File tree

6 files changed

+123
-153
lines changed

6 files changed

+123
-153
lines changed

src/backend/fastapi_app/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def create_app(testing: bool = False):
5858
else:
5959
if not testing:
6060
load_dotenv(override=True)
61-
logging.basicConfig(level=logging.INFO)
61+
logging.basicConfig(level=logging.DEBUG)
6262

6363
# Turn off particularly noisy INFO level logs from Azure Core SDK:
6464
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)

src/backend/fastapi_app/api_models.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from enum import Enum
2-
from typing import Any, Optional, Union
2+
from typing import Any, Optional
33

44
from openai.types.chat import ChatCompletionMessageParam
55
from pydantic import BaseModel, Field
6-
from pydantic_ai.messages import ModelRequest, ModelResponse
76

87

98
class AIChatRoles(str, Enum):
@@ -96,7 +95,7 @@ class ChatParams(ChatRequestOverrides):
9695
enable_text_search: bool
9796
enable_vector_search: bool
9897
original_user_query: str
99-
past_messages: list[Union[ModelRequest, ModelResponse]]
98+
past_messages: list[ChatCompletionMessageParam]
10099

101100

102101
class Filter(BaseModel):
Lines changed: 24 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,34 @@
11
[
2-
{
3-
"parts": [
4-
{
5-
"content": "good options for climbing gear that can be used outside?",
6-
"timestamp": "2025-05-07T19:02:46.977501Z",
7-
"part_kind": "user-prompt"
8-
}
9-
],
10-
"instructions": null,
11-
"kind": "request"
12-
},
13-
{
14-
"parts": [
15-
{
16-
"tool_name": "search_database",
17-
"args": "{\"search_query\":\"climbing gear outside\"}",
18-
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
19-
"part_kind": "tool-call"
20-
}
21-
],
22-
"model_name": "gpt-4o-mini-2024-07-18",
23-
"timestamp": "2025-05-07T19:02:47Z",
24-
"kind": "response"
25-
},
26-
{
27-
"parts": [
2+
{"role": "user", "content": "good options for climbing gear that can be used outside?"},
3+
{"role": "assistant", "tool_calls": [
284
{
29-
"tool_name": "search_database",
30-
"content": "Search results for climbing gear that can be used outside: ...",
31-
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
32-
"timestamp": "2025-05-07T19:02:48.242408Z",
33-
"part_kind": "tool-return"
5+
"id": "call_abc123",
6+
"type": "function",
7+
"function": {
8+
"arguments": "{\"search_query\":\"climbing gear outside\"}",
9+
"name": "search_database"
10+
}
3411
}
35-
],
36-
"instructions": null,
37-
"kind": "request"
38-
},
12+
]},
3913
{
40-
"parts": [
41-
{
42-
"content": "are there any shoes less than $50?",
43-
"timestamp": "2025-05-07T19:02:46.977501Z",
44-
"part_kind": "user-prompt"
45-
}
46-
],
47-
"instructions": null,
48-
"kind": "request"
14+
"role": "tool",
15+
"tool_call_id": "call_abc123",
16+
"content": "Search results for climbing gear that can be used outside: ..."
4917
},
50-
{
51-
"parts": [
18+
{"role": "user", "content": "are there any shoes less than $50?"},
19+
{"role": "assistant", "tool_calls": [
5220
{
53-
"tool_name": "search_database",
54-
"args": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
55-
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
56-
"part_kind": "tool-call"
21+
"id": "call_abc456",
22+
"type": "function",
23+
"function": {
24+
"arguments": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
25+
"name": "search_database"
26+
}
5727
}
58-
],
59-
"model_name": "gpt-4o-mini-2024-07-18",
60-
"timestamp": "2025-05-07T19:02:47Z",
61-
"kind": "response"
62-
},
28+
]},
6329
{
64-
"parts": [
65-
{
66-
"tool_name": "search_database",
67-
"content": "Search results for shoes cheaper than 50: ...",
68-
"tool_call_id": "call_4HeBCmo2uioV6CyoePEGyZPc",
69-
"timestamp": "2025-05-07T19:02:48.242408Z",
70-
"part_kind": "tool-return"
71-
}
72-
],
73-
"instructions": null,
74-
"kind": "request"
30+
"role": "tool",
31+
"tool_call_id": "call_abc456",
32+
"content": "Search results for shoes cheaper than 50: ..."
7533
}
7634
]
Lines changed: 92 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1+
import json
12
from collections.abc import AsyncGenerator
23
from typing import Optional, Union
34

5+
from agents import Agent, ModelSettings, OpenAIChatCompletionsModel, Runner, function_tool, set_tracing_disabled
46
from openai import AsyncAzureOpenAI, AsyncOpenAI
5-
from openai.types.chat import ChatCompletionMessageParam
6-
from pydantic_ai import Agent, RunContext
7-
from pydantic_ai.messages import ModelMessagesTypeAdapter
8-
from pydantic_ai.models.openai import OpenAIModel
9-
from pydantic_ai.providers.openai import OpenAIProvider
10-
from pydantic_ai.settings import ModelSettings
7+
from openai.types.chat import (
8+
ChatCompletionMessageParam,
9+
)
10+
from openai.types.responses import (
11+
EasyInputMessageParam,
12+
ResponseFunctionToolCallParam,
13+
ResponseTextDeltaEvent,
14+
)
15+
from openai.types.responses.response_input_item_param import FunctionCallOutput
1116

1217
from fastapi_app.api_models import (
1318
AIChatRoles,
@@ -24,7 +29,9 @@
2429
ThoughtStep,
2530
)
2631
from fastapi_app.postgres_searcher import PostgresSearcher
27-
from fastapi_app.rag_base import ChatParams, RAGChatBase
32+
from fastapi_app.rag_base import RAGChatBase
33+
34+
set_tracing_disabled(disabled=True)
2835

2936

3037
class AdvancedRAGChat(RAGChatBase):
@@ -46,34 +53,29 @@ def __init__(
4653
self.model_for_thoughts = (
4754
{"model": chat_model, "deployment": chat_deployment} if chat_deployment else {"model": chat_model}
4855
)
49-
pydantic_chat_model = OpenAIModel(
50-
chat_model if chat_deployment is None else chat_deployment,
51-
provider=OpenAIProvider(openai_client=openai_chat_client),
56+
openai_agents_model = OpenAIChatCompletionsModel(
57+
model=chat_model if chat_deployment is None else chat_deployment, openai_client=openai_chat_client
5258
)
53-
self.search_agent = Agent[ChatParams, SearchResults](
54-
pydantic_chat_model,
55-
model_settings=ModelSettings(
56-
temperature=0.0,
57-
max_tokens=500,
58-
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
59-
),
60-
system_prompt=self.query_prompt_template,
61-
tools=[self.search_database],
62-
output_type=SearchResults,
59+
self.search_agent = Agent(
60+
name="Searcher",
61+
instructions=self.query_prompt_template,
62+
tools=[function_tool(self.search_database)],
63+
tool_use_behavior="stop_on_first_tool",
64+
model=openai_agents_model,
6365
)
6466
self.answer_agent = Agent(
65-
pydantic_chat_model,
66-
system_prompt=self.answer_prompt_template,
67+
name="Answerer",
68+
instructions=self.answer_prompt_template,
69+
model=openai_agents_model,
6770
model_settings=ModelSettings(
6871
temperature=self.chat_params.temperature,
6972
max_tokens=self.chat_params.response_token_limit,
70-
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
73+
extra_body={"seed": self.chat_params.seed} if self.chat_params.seed is not None else {},
7174
),
7275
)
7376

7477
async def search_database(
7578
self,
76-
ctx: RunContext[ChatParams],
7779
search_query: str,
7880
price_filter: Optional[PriceFilter] = None,
7981
brand_filter: Optional[BrandFilter] = None,
@@ -97,66 +99,88 @@ async def search_database(
9799
filters.append(brand_filter)
98100
results = await self.searcher.search_and_embed(
99101
search_query,
100-
top=ctx.deps.top,
101-
enable_vector_search=ctx.deps.enable_vector_search,
102-
enable_text_search=ctx.deps.enable_text_search,
102+
top=self.chat_params.top,
103+
enable_vector_search=self.chat_params.enable_vector_search,
104+
enable_text_search=self.chat_params.enable_text_search,
103105
filters=filters,
104106
)
105107
return SearchResults(
106108
query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
107109
)
108110

109111
async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
110-
few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots)
112+
few_shots = json.loads(self.query_fewshots)
113+
few_shot_inputs = []
114+
for few_shot in few_shots:
115+
if few_shot["role"] == "user":
116+
message = EasyInputMessageParam(role="user", content=few_shot["content"])
117+
elif few_shot["role"] == "assistant" and few_shot["tool_calls"] is not None:
118+
message = ResponseFunctionToolCallParam(
119+
id="madeup",
120+
call_id=few_shot["tool_calls"][0]["id"],
121+
name=few_shot["tool_calls"][0]["function"]["name"],
122+
arguments=few_shot["tool_calls"][0]["function"]["arguments"],
123+
type="function_call",
124+
)
125+
elif few_shot["role"] == "tool" and few_shot["tool_call_id"] is not None:
126+
message = FunctionCallOutput(
127+
id="madeupoutput",
128+
call_id=few_shot["tool_call_id"],
129+
output=few_shot["content"],
130+
type="function_call_output",
131+
)
132+
few_shot_inputs.append(message)
133+
111134
user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
112-
results = await self.search_agent.run(
113-
user_query,
114-
message_history=few_shots + self.chat_params.past_messages,
115-
deps=self.chat_params,
116-
)
117-
items = results.output.items
135+
new_user_message = EasyInputMessageParam(role="user", content=user_query)
136+
all_messages = few_shot_inputs + self.chat_params.past_messages + [new_user_message]
137+
138+
run_results = await Runner.run(self.search_agent, input=all_messages)
139+
search_results = run_results.new_items[-1].output
140+
118141
thoughts = [
119142
ThoughtStep(
120143
title="Prompt to generate search arguments",
121-
description=results.all_messages(),
144+
description=run_results.input,
122145
props=self.model_for_thoughts,
123146
),
124147
ThoughtStep(
125148
title="Search using generated search arguments",
126-
description=results.output.query,
149+
description=search_results.query,
127150
props={
128151
"top": self.chat_params.top,
129152
"vector_search": self.chat_params.enable_vector_search,
130153
"text_search": self.chat_params.enable_text_search,
131-
"filters": results.output.filters,
154+
"filters": search_results.filters,
132155
},
133156
),
134157
ThoughtStep(
135158
title="Search results",
136-
description=items,
159+
description=search_results.items,
137160
),
138161
]
139-
return items, thoughts
162+
return search_results.items, thoughts
140163

141164
async def answer(
142165
self,
143166
items: list[ItemPublic],
144167
earlier_thoughts: list[ThoughtStep],
145168
) -> RetrievalResponse:
146-
response = await self.answer_agent.run(
147-
user_prompt=self.prepare_rag_request(self.chat_params.original_user_query, items),
148-
message_history=self.chat_params.past_messages,
169+
run_results = await Runner.run(
170+
self.answer_agent,
171+
input=self.chat_params.past_messages
172+
+ [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],
149173
)
150174

151175
return RetrievalResponse(
152-
message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
176+
message=Message(content=str(run_results.final_output), role=AIChatRoles.ASSISTANT),
153177
context=RAGContext(
154178
data_points={item.id: item for item in items},
155179
thoughts=earlier_thoughts
156180
+ [
157181
ThoughtStep(
158182
title="Prompt to generate answer",
159-
description=response.all_messages(),
183+
description=run_results.input,
160184
props=self.model_for_thoughts,
161185
),
162186
],
@@ -168,24 +192,27 @@ async def answer_stream(
168192
items: list[ItemPublic],
169193
earlier_thoughts: list[ThoughtStep],
170194
) -> AsyncGenerator[RetrievalResponseDelta, None]:
171-
async with self.answer_agent.run_stream(
172-
self.prepare_rag_request(self.chat_params.original_user_query, items),
173-
message_history=self.chat_params.past_messages,
174-
) as agent_stream_runner:
175-
yield RetrievalResponseDelta(
176-
context=RAGContext(
177-
data_points={item.id: item for item in items},
178-
thoughts=earlier_thoughts
179-
+ [
180-
ThoughtStep(
181-
title="Prompt to generate answer",
182-
description=agent_stream_runner.all_messages(),
183-
props=self.model_for_thoughts,
184-
),
185-
],
186-
),
187-
)
188-
189-
async for message in agent_stream_runner.stream_text(delta=True, debounce_by=None):
190-
yield RetrievalResponseDelta(delta=Message(content=str(message), role=AIChatRoles.ASSISTANT))
191-
return
195+
run_results = Runner.run_streamed(
196+
self.answer_agent,
197+
input=self.chat_params.past_messages
198+
+ [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],
199+
)
200+
201+
yield RetrievalResponseDelta(
202+
context=RAGContext(
203+
data_points={item.id: item for item in items},
204+
thoughts=earlier_thoughts
205+
+ [
206+
ThoughtStep(
207+
title="Prompt to generate answer",
208+
description=run_results.input,
209+
props=self.model_for_thoughts,
210+
),
211+
],
212+
),
213+
)
214+
215+
async for event in run_results.stream_events():
216+
if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
217+
yield RetrievalResponseDelta(delta=Message(content=str(event.data.delta), role=AIChatRoles.ASSISTANT))
218+
return

0 commit comments

Comments
 (0)