Skip to content

Commit c6b1801

Browse files
committed
Port to OpenAI-agents SDK
1 parent 42e7795 commit c6b1801

File tree

1 file changed

+40
-35
lines changed

1 file changed

+40
-35
lines changed

src/backend/fastapi_app/rag_simple.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from collections.abc import AsyncGenerator
22
from typing import Optional, Union
33

4+
from agents import Agent, ModelSettings, OpenAIChatCompletionsModel, Runner, set_tracing_disabled
45
from openai import AsyncAzureOpenAI, AsyncOpenAI
56
from openai.types.chat import ChatCompletionMessageParam
6-
from pydantic_ai import Agent
7-
from pydantic_ai.models.openai import OpenAIModel
8-
from pydantic_ai.providers.openai import OpenAIProvider
9-
from pydantic_ai.settings import ModelSettings
7+
from openai.types.responses import ResponseTextDeltaEvent
108

119
from fastapi_app.api_models import (
1210
AIChatRoles,
@@ -21,6 +19,8 @@
2119
from fastapi_app.postgres_searcher import PostgresSearcher
2220
from fastapi_app.rag_base import RAGChatBase
2321

22+
set_tracing_disabled(disabled=True)
23+
2424

2525
class SimpleRAGChat(RAGChatBase):
2626
def __init__(
@@ -38,17 +38,17 @@ def __init__(
3838
self.model_for_thoughts = (
3939
{"model": chat_model, "deployment": chat_deployment} if chat_deployment else {"model": chat_model}
4040
)
41-
pydantic_chat_model = OpenAIModel(
42-
chat_model if chat_deployment is None else chat_deployment,
43-
provider=OpenAIProvider(openai_client=openai_chat_client),
41+
openai_agents_model = OpenAIChatCompletionsModel(
42+
model=chat_model if chat_deployment is None else chat_deployment, openai_client=openai_chat_client
4443
)
4544
self.answer_agent = Agent(
46-
pydantic_chat_model,
47-
system_prompt=self.answer_prompt_template,
45+
name="Answerer",
46+
instructions=self.answer_prompt_template,
47+
model=openai_agents_model,
4848
model_settings=ModelSettings(
4949
temperature=self.chat_params.temperature,
5050
max_tokens=self.chat_params.response_token_limit,
51-
**({"seed": self.chat_params.seed} if self.chat_params.seed is not None else {}),
51+
extra_body={"seed": self.chat_params.seed} if self.chat_params.seed is not None else {},
5252
),
5353
)
5454

@@ -85,19 +85,21 @@ async def answer(
8585
items: list[ItemPublic],
8686
earlier_thoughts: list[ThoughtStep],
8787
) -> RetrievalResponse:
88-
response = await self.answer_agent.run(
89-
user_prompt=self.prepare_rag_request(self.chat_params.original_user_query, items),
90-
message_history=self.chat_params.past_messages,
88+
run_results = await Runner.run(
89+
self.answer_agent,
90+
input=self.chat_params.past_messages
91+
+ [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],
9192
)
93+
9294
return RetrievalResponse(
93-
message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
95+
message=Message(content=str(run_results.final_output), role=AIChatRoles.ASSISTANT),
9496
context=RAGContext(
9597
data_points={item.id: item for item in items},
9698
thoughts=earlier_thoughts
9799
+ [
98100
ThoughtStep(
99101
title="Prompt to generate answer",
100-
description=response.all_messages(),
102+
description=run_results.input,
101103
props=self.model_for_thoughts,
102104
),
103105
],
@@ -109,24 +111,27 @@ async def answer_stream(
109111
items: list[ItemPublic],
110112
earlier_thoughts: list[ThoughtStep],
111113
) -> AsyncGenerator[RetrievalResponseDelta, None]:
112-
async with self.answer_agent.run_stream(
113-
self.prepare_rag_request(self.chat_params.original_user_query, items),
114-
message_history=self.chat_params.past_messages,
115-
) as agent_stream_runner:
116-
yield RetrievalResponseDelta(
117-
context=RAGContext(
118-
data_points={item.id: item for item in items},
119-
thoughts=earlier_thoughts
120-
+ [
121-
ThoughtStep(
122-
title="Prompt to generate answer",
123-
description=agent_stream_runner.all_messages(),
124-
props=self.model_for_thoughts,
125-
),
126-
],
127-
),
128-
)
114+
run_results = Runner.run_streamed(
115+
self.answer_agent,
116+
input=self.chat_params.past_messages
117+
+ [{"content": self.prepare_rag_request(self.chat_params.original_user_query, items), "role": "user"}],
118+
)
119+
120+
yield RetrievalResponseDelta(
121+
context=RAGContext(
122+
data_points={item.id: item for item in items},
123+
thoughts=earlier_thoughts
124+
+ [
125+
ThoughtStep(
126+
title="Prompt to generate answer",
127+
description=run_results.input,
128+
props=self.model_for_thoughts,
129+
),
130+
],
131+
),
132+
)
129133

130-
async for message in agent_stream_runner.stream_text(delta=True, debounce_by=None):
131-
yield RetrievalResponseDelta(delta=Message(content=str(message), role=AIChatRoles.ASSISTANT))
132-
return
134+
async for event in run_results.stream_events():
135+
if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
136+
yield RetrievalResponseDelta(delta=Message(content=str(event.data.delta), role=AIChatRoles.ASSISTANT))
137+
return

0 commit comments

Comments (0)