Skip to content

Commit b1b8746

Browse files
committed
More Pydantic-AI usage
1 parent c4d2a7f commit b1b8746

File tree

4 files changed

+31
-15
lines changed

4 files changed

+31
-15
lines changed

src/backend/fastapi_app/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ class State(TypedDict):
3434
@asynccontextmanager
3535
async def lifespan(app: fastapi.FastAPI) -> AsyncIterator[State]:
3636
context = await common_parameters()
37-
azure_credential = await get_azure_credential()
37+
azure_credential = None
38+
if (
39+
os.getenv("OPENAI_CHAT_HOST") == "azure"
40+
or os.getenv("OPENAI_EMBED_HOST") == "azure"
41+
or os.getenv("POSTGRES_HOST").endswith(".database.azure.com")
42+
):
43+
azure_credential = await get_azure_credential()
3844
engine = await create_postgres_engine_from_env(azure_credential)
3945
sessionmaker = await create_async_sessionmaker(engine)
4046
chat_client = await create_openai_chat_client(azure_credential)

src/backend/fastapi_app/api_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ class ItemPublic(BaseModel):
7777
description: str
7878
price: float
7979

80+
def to_str_for_rag(self):
81+
return f"Name:{self.name} Description:{self.description} Price:{self.price} Brand:{self.brand} Type:{self.type}"
82+
8083

8184
class ItemWithDistance(ItemPublic):
8285
distance: float

src/backend/fastapi_app/rag_advanced.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from fastapi_app.api_models import (
1414
AIChatRoles,
15+
ItemPublic,
1516
Message,
1617
RAGContext,
1718
RetrievalResponse,
@@ -50,6 +51,14 @@ class BrandFilter(TypedDict):
5051
"""The brand name to compare against (e.g., 'AirStrider')"""
5152

5253

54+
class SearchResults(TypedDict):
55+
items: list[ItemPublic]
56+
"""List of items that match the search query and filters"""
57+
58+
filters: list[Union[PriceFilter, BrandFilter]]
59+
"""List of filters applied to the search results"""
60+
61+
5362
class AdvancedRAGChat(RAGChatBase):
5463
def __init__(
5564
self,
@@ -71,7 +80,7 @@ async def search_database(
7180
search_query: str,
7281
price_filter: Optional[PriceFilter] = None,
7382
brand_filter: Optional[BrandFilter] = None,
74-
) -> list[str]:
83+
) -> SearchResults:
7584
"""
7685
Search PostgreSQL database for relevant products based on user query
7786
@@ -83,7 +92,6 @@ async def search_database(
8392
Returns:
8493
List of formatted items that match the search query and filters
8594
"""
86-
print(search_query, price_filter, brand_filter)
8795
# Only send non-None filters
8896
filters = []
8997
if price_filter:
@@ -97,9 +105,9 @@ async def search_database(
97105
enable_text_search=ctx.deps.enable_text_search,
98106
filters=filters,
99107
)
100-
return [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
108+
return SearchResults(items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters)
101109

102-
async def prepare_context(self, chat_params: ChatParams) -> tuple[str, list[Item], list[ThoughtStep]]:
110+
async def prepare_context(self, chat_params: ChatParams) -> tuple[list[ItemPublic], list[ThoughtStep]]:
103111
model = OpenAIModel(
104112
os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=self.openai_chat_client)
105113
)
@@ -108,17 +116,15 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[str, list[Item
108116
model_settings=ModelSettings(temperature=0.0, max_tokens=500, seed=chat_params.seed),
109117
system_prompt=self.query_prompt_template,
110118
tools=[self.search_database],
111-
output_type=list[str],
119+
output_type=SearchResults,
112120
)
113121
# TODO: Provide few-shot examples
114122
results = await agent.run(
115123
f"Find search results for user query: {chat_params.original_user_query}",
116124
# message_history=chat_params.past_messages, # TODO
117125
deps=chat_params,
118126
)
119-
if not isinstance(results, list):
120-
raise ValueError("Search results should be a list of strings")
121-
127+
items = results.output.items
122128
thoughts = [
123129
ThoughtStep(
124130
title="Prompt to generate search arguments",
@@ -144,12 +150,12 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[str, list[Item
144150
description="", # TODO
145151
),
146152
]
147-
return results, thoughts
153+
return items, thoughts
148154

149155
async def answer(
150156
self,
151157
chat_params: ChatParams,
152-
results: list[str],
158+
items: list[ItemPublic],
153159
earlier_thoughts: list[ThoughtStep],
154160
) -> RetrievalResponse:
155161
agent = Agent(
@@ -163,15 +169,16 @@ async def answer(
163169
),
164170
)
165171

172+
item_references = [item.to_str_for_rag() for item in items]
166173
response = await agent.run(
167-
user_prompt=chat_params.original_user_query + "Sources:\n" + "\n".join(results),
174+
user_prompt=chat_params.original_user_query + "Sources:\n" + "\n".join(item_references),
168175
message_history=chat_params.past_messages,
169176
)
170177

171178
return RetrievalResponse(
172179
message=Message(content=str(response.output), role=AIChatRoles.ASSISTANT),
173180
context=RAGContext(
174-
data_points={item.id: item.to_dict() for item in results},
181+
data_points={}, # TODO
175182
thoughts=earlier_thoughts
176183
+ [
177184
ThoughtStep(

src/backend/fastapi_app/routes/api_routes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ async def chat_handler(
136136

137137
chat_params = rag_flow.get_params(chat_request.messages, chat_request.context.overrides)
138138

139-
results, thoughts = await rag_flow.prepare_context(chat_params)
140-
response = await rag_flow.answer(chat_params=chat_params, results=results, earlier_thoughts=thoughts)
139+
items, thoughts = await rag_flow.prepare_context(chat_params)
140+
response = await rag_flow.answer(chat_params=chat_params, items=items, earlier_thoughts=thoughts)
141141
return response
142142
except Exception as e:
143143
if isinstance(e, APIError) and e.code == "content_filter":

0 commit comments

Comments
 (0)