Skip to content

Commit 7bb7453

Browse files
committed
Fix tests, mypy
1 parent c6b1801 commit 7bb7453

File tree

10 files changed

+93
-303
lines changed

10 files changed

+93
-303
lines changed

src/backend/fastapi_app/api_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import Enum
22
from typing import Any, Optional
33

4-
from openai.types.chat import ChatCompletionMessageParam
4+
from openai.types.responses import ResponseInputItemParam
55
from pydantic import BaseModel, Field
66

77

@@ -36,7 +36,7 @@ class ChatRequestContext(BaseModel):
3636

3737

3838
class ChatRequest(BaseModel):
39-
messages: list[ChatCompletionMessageParam]
39+
messages: list[ResponseInputItemParam]
4040
context: ChatRequestContext
4141
sessionState: Optional[Any] = None
4242

@@ -95,7 +95,7 @@ class ChatParams(ChatRequestOverrides):
9595
enable_text_search: bool
9696
enable_vector_search: bool
9797
original_user_query: str
98-
past_messages: list[ChatCompletionMessageParam]
98+
past_messages: list[ResponseInputItemParam]
9999

100100

101101
class Filter(BaseModel):
Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,36 @@
11
[
2-
{"role": "user", "content": "good options for climbing gear that can be used outside?"},
3-
{"role": "assistant", "tool_calls": [
4-
{
5-
"id": "call_abc123",
6-
"type": "function",
7-
"function": {
8-
"arguments": "{\"search_query\":\"climbing gear outside\"}",
9-
"name": "search_database"
10-
}
11-
}
12-
]},
132
{
14-
"role": "tool",
15-
"tool_call_id": "call_abc123",
16-
"content": "Search results for climbing gear that can be used outside: ..."
3+
"role": "user",
4+
"content": "good options for climbing gear that can be used outside?"
175
},
18-
{"role": "user", "content": "are there any shoes less than $50?"},
19-
{"role": "assistant", "tool_calls": [
20-
{
21-
"id": "call_abc456",
22-
"type": "function",
23-
"function": {
24-
"arguments": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
25-
"name": "search_database"
26-
}
27-
}
28-
]},
296
{
30-
"role": "tool",
31-
"tool_call_id": "call_abc456",
32-
"content": "Search results for shoes cheaper than 50: ..."
7+
"id": "madeup",
8+
"call_id": "call_abc123",
9+
"name": "search_database",
10+
"arguments": "{\"search_query\":\"climbing gear outside\"}",
11+
"type": "function_call"
12+
},
13+
{
14+
"id": "madeupoutput",
15+
"call_id": "call_abc123",
16+
"output": "Search results for climbing gear that can be used outside: ...",
17+
"type": "function_call_output"
18+
},
19+
{
20+
"role": "user",
21+
"content": "are there any shoes less than $50?"
22+
},
23+
{
24+
"id": "madeup",
25+
"call_id": "call_abc456",
26+
"name": "search_database",
27+
"arguments": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
28+
"type": "function_call"
29+
},
30+
{
31+
"id": "madeupoutput",
32+
"call_id": "call_abc456",
33+
"output": "Search results for shoes cheaper than 50: ...",
34+
"type": "function_call_output"
3335
}
3436
]

src/backend/fastapi_app/rag_advanced.py

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22
from collections.abc import AsyncGenerator
33
from typing import Optional, Union
44

5-
from agents import Agent, ModelSettings, OpenAIChatCompletionsModel, Runner, function_tool, set_tracing_disabled
6-
from openai import AsyncAzureOpenAI, AsyncOpenAI
7-
from openai.types.chat import (
8-
ChatCompletionMessageParam,
9-
)
10-
from openai.types.responses import (
11-
EasyInputMessageParam,
12-
ResponseFunctionToolCallParam,
13-
ResponseTextDeltaEvent,
5+
from agents import (
6+
Agent,
7+
ModelSettings,
8+
OpenAIChatCompletionsModel,
9+
Runner,
10+
ToolCallOutputItem,
11+
function_tool,
12+
set_tracing_disabled,
1413
)
15-
from openai.types.responses.response_input_item_param import FunctionCallOutput
14+
from openai import AsyncAzureOpenAI, AsyncOpenAI
15+
from openai.types.responses import EasyInputMessageParam, ResponseInputItemParam, ResponseTextDeltaEvent
1616

1717
from fastapi_app.api_models import (
1818
AIChatRoles,
@@ -41,7 +41,7 @@ class AdvancedRAGChat(RAGChatBase):
4141
def __init__(
4242
self,
4343
*,
44-
messages: list[ChatCompletionMessageParam],
44+
messages: list[ResponseInputItemParam],
4545
overrides: ChatRequestOverrides,
4646
searcher: PostgresSearcher,
4747
openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
@@ -109,34 +109,17 @@ async def search_database(
109109
)
110110

111111
async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]:
112-
few_shots = json.loads(self.query_fewshots)
113-
few_shot_inputs = []
114-
for few_shot in few_shots:
115-
if few_shot["role"] == "user":
116-
message = EasyInputMessageParam(role="user", content=few_shot["content"])
117-
elif few_shot["role"] == "assistant" and few_shot["tool_calls"] is not None:
118-
message = ResponseFunctionToolCallParam(
119-
id="madeup",
120-
call_id=few_shot["tool_calls"][0]["id"],
121-
name=few_shot["tool_calls"][0]["function"]["name"],
122-
arguments=few_shot["tool_calls"][0]["function"]["arguments"],
123-
type="function_call",
124-
)
125-
elif few_shot["role"] == "tool" and few_shot["tool_call_id"] is not None:
126-
message = FunctionCallOutput(
127-
id="madeupoutput",
128-
call_id=few_shot["tool_call_id"],
129-
output=few_shot["content"],
130-
type="function_call_output",
131-
)
132-
few_shot_inputs.append(message)
133-
112+
few_shots: list[ResponseInputItemParam] = json.loads(self.query_fewshots)
134113
user_query = f"Find search results for user query: {self.chat_params.original_user_query}"
135114
new_user_message = EasyInputMessageParam(role="user", content=user_query)
136-
all_messages = few_shot_inputs + self.chat_params.past_messages + [new_user_message]
115+
all_messages = few_shots + self.chat_params.past_messages + [new_user_message]
137116

138117
run_results = await Runner.run(self.search_agent, input=all_messages)
139-
search_results = run_results.new_items[-1].output
118+
most_recent_response = run_results.new_items[-1]
119+
if isinstance(most_recent_response, ToolCallOutputItem):
120+
search_results = most_recent_response.output
121+
else:
122+
raise ValueError("Error retrieving search results, model did not call tool properly")
140123

141124
thoughts = [
142125
ThoughtStep(

src/backend/fastapi_app/rag_base.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from abc import ABC, abstractmethod
33
from collections.abc import AsyncGenerator
44

5-
from openai.types.chat import ChatCompletionMessageParam
5+
from openai.types.responses import ResponseInputItemParam
66

77
from fastapi_app.api_models import (
88
ChatParams,
@@ -18,16 +18,14 @@ class RAGChatBase(ABC):
1818
prompts_dir = pathlib.Path(__file__).parent / "prompts/"
1919
answer_prompt_template = open(prompts_dir / "answer.txt").read()
2020

21-
def get_chat_params(
22-
self, messages: list[ChatCompletionMessageParam], overrides: ChatRequestOverrides
23-
) -> ChatParams:
21+
def get_chat_params(self, messages: list[ResponseInputItemParam], overrides: ChatRequestOverrides) -> ChatParams:
2422
response_token_limit = 1024
2523
prompt_template = overrides.prompt_template or self.answer_prompt_template
2624

2725
enable_text_search = overrides.retrieval_mode in ["text", "hybrid", None]
2826
enable_vector_search = overrides.retrieval_mode in ["vectors", "hybrid", None]
2927

30-
original_user_query = messages[-1]["content"]
28+
original_user_query = messages[-1].get("content")
3129
if not isinstance(original_user_query, str):
3230
raise ValueError("The most recent message content must be a string.")
3331

src/backend/fastapi_app/rag_simple.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33

44
from agents import Agent, ModelSettings, OpenAIChatCompletionsModel, Runner, set_tracing_disabled
55
from openai import AsyncAzureOpenAI, AsyncOpenAI
6-
from openai.types.chat import ChatCompletionMessageParam
7-
from openai.types.responses import ResponseTextDeltaEvent
6+
from openai.types.responses import ResponseInputItemParam, ResponseTextDeltaEvent
87

98
from fastapi_app.api_models import (
109
AIChatRoles,
@@ -26,7 +25,7 @@ class SimpleRAGChat(RAGChatBase):
2625
def __init__(
2726
self,
2827
*,
29-
messages: list[ChatCompletionMessageParam],
28+
messages: list[ResponseInputItemParam],
3029
overrides: ChatRequestOverrides,
3130
searcher: PostgresSearcher,
3231
openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI],

0 commit comments

Comments
 (0)