Skip to content

Commit 202fa4b

Browse files
committed
More Pydantic AI changes
1 parent b1b8746 commit 202fa4b

File tree

3 files changed

+39
-33
lines changed

3 files changed

+39
-33
lines changed

src/backend/fastapi_app/api_models.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,34 @@ class ChatRequest(BaseModel):
4141
sessionState: Optional[Any] = None
4242

4343

44+
class ItemPublic(BaseModel):
45+
id: int
46+
type: str
47+
brand: str
48+
name: str
49+
description: str
50+
price: float
51+
52+
def to_str_for_rag(self):
53+
return f"Name:{self.name} Description:{self.description} Price:{self.price} Brand:{self.brand} Type:{self.type}"
54+
55+
56+
class ItemWithDistance(ItemPublic):
57+
distance: float
58+
59+
def __init__(self, **data):
60+
super().__init__(**data)
61+
self.distance = round(self.distance, 2)
62+
63+
4464
class ThoughtStep(BaseModel):
4565
title: str
4666
description: Any
4767
props: dict = {}
4868

4969

5070
class RAGContext(BaseModel):
51-
data_points: dict[int, dict[str, Any]]
71+
data_points: dict[int, ItemPublic]
5272
thoughts: list[ThoughtStep]
5373
followup_questions: Optional[list[str]] = None
5474

@@ -69,26 +89,6 @@ class RetrievalResponseDelta(BaseModel):
6989
sessionState: Optional[Any] = None
7090

7191

72-
class ItemPublic(BaseModel):
73-
id: int
74-
type: str
75-
brand: str
76-
name: str
77-
description: str
78-
price: float
79-
80-
def to_str_for_rag(self):
81-
return f"Name:{self.name} Description:{self.description} Price:{self.price} Brand:{self.brand} Type:{self.type}"
82-
83-
84-
class ItemWithDistance(ItemPublic):
85-
distance: float
86-
87-
def __init__(self, **data):
88-
super().__init__(**data)
89-
self.distance = round(self.distance, 2)
90-
91-
9292
class ChatParams(ChatRequestOverrides):
9393
prompt_template: str
9494
response_token_limit: int = 1024

src/backend/fastapi_app/openai_clients.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ async def create_openai_chat_client(
1414
openai_chat_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
1515
OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
1616
if OPENAI_CHAT_HOST == "azure":
17-
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-03-01-preview"
17+
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-10-21"
1818
azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
1919
azure_deployment = os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"]
2020
if api_key := os.getenv("AZURE_OPENAI_KEY"):

src/backend/fastapi_app/rag_advanced.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class BrandFilter(TypedDict):
5252

5353

5454
class SearchResults(TypedDict):
55+
query: str
56+
"""The original search query"""
57+
5558
items: list[ItemPublic]
5659
"""List of items that match the search query and filters"""
5760

@@ -105,7 +108,9 @@ async def search_database(
105108
enable_text_search=ctx.deps.enable_text_search,
106109
filters=filters,
107110
)
108-
return SearchResults(items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters)
111+
return SearchResults(
112+
query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters
113+
)
109114

110115
async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPublic], list[ThoughtStep]]:
111116
model = OpenAIModel(
@@ -119,35 +124,36 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPubli
119124
output_type=SearchResults,
120125
)
121126
# TODO: Provide few-shot examples
127+
user_query = f"Find search results for user query: {chat_params.original_user_query}"
122128
results = await agent.run(
123-
f"Find search results for user query: {chat_params.original_user_query}",
124-
# message_history=chat_params.past_messages, # TODO
129+
user_query,
130+
message_history=chat_params.past_messages,
125131
deps=chat_params,
126132
)
127-
items = results.output.items
133+
items = results.output["items"]
128134
thoughts = [
129135
ThoughtStep(
130136
title="Prompt to generate search arguments",
131-
description=chat_params.past_messages, # TODO: update this
137+
description=results.all_messages(),
132138
props=(
133139
{"model": self.chat_model, "deployment": self.chat_deployment}
134140
if self.chat_deployment
135-
else {"model": self.chat_model}
141+
else {"model": self.chat_model} # TODO
136142
),
137143
),
138144
ThoughtStep(
139145
title="Search using generated search arguments",
140-
description=chat_params.original_user_query, # TODO:
146+
description=results.output["query"],
141147
props={
142148
"top": chat_params.top,
143149
"vector_search": chat_params.enable_vector_search,
144150
"text_search": chat_params.enable_text_search,
145-
"filters": [], # TODO
151+
"filters": results.output["filters"],
146152
},
147153
),
148154
ThoughtStep(
149155
title="Search results",
150-
description="", # TODO
156+
description=items,
151157
),
152158
]
153159
return items, thoughts
@@ -178,12 +184,12 @@ async def answer(
178184
return RetrievalResponse(
179185
message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
180186
context=RAGContext(
181-
data_points={}, # TODO
187+
data_points={item.id: item for item in items},
182188
thoughts=earlier_thoughts
183189
+ [
184190
ThoughtStep(
185191
title="Prompt to generate answer",
186-
description="", # TODO: update
192+
description=response.all_messages(),
187193
props=(
188194
{"model": self.chat_model, "deployment": self.chat_deployment}
189195
if self.chat_deployment

0 commit comments

Comments (0)