From fcc9dd22a2bc2729b187a708394aeee37bcd2d86 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Thu, 8 May 2025 18:44:52 +0000 Subject: [PATCH] Change to output type --- src/backend/fastapi_app/__init__.py | 2 +- src/backend/fastapi_app/api_models.py | 6 +++ src/backend/fastapi_app/rag_advanced.py | 49 ++++++++++++------------- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/src/backend/fastapi_app/__init__.py b/src/backend/fastapi_app/__init__.py index b760fdb2..4f2fd484 100644 --- a/src/backend/fastapi_app/__init__.py +++ b/src/backend/fastapi_app/__init__.py @@ -58,7 +58,7 @@ def create_app(testing: bool = False): else: if not testing: load_dotenv(override=True) - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.DEBUG) # Turn off particularly noisy INFO level logs from Azure Core SDK: logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING) diff --git a/src/backend/fastapi_app/api_models.py b/src/backend/fastapi_app/api_models.py index 46574c4e..8b5af1f0 100644 --- a/src/backend/fastapi_app/api_models.py +++ b/src/backend/fastapi_app/api_models.py @@ -117,6 +117,12 @@ class BrandFilter(Filter): value: str = Field(description="The brand name to compare against (e.g., 'AirStrider')") +class SearchArguments(BaseModel): + search_query: str + price_filter: Optional[PriceFilter] = Field(default=None) + brand_filter: Optional[BrandFilter] = Field(default=None) + + class SearchResults(BaseModel): query: str """The original search query""" diff --git a/src/backend/fastapi_app/rag_advanced.py b/src/backend/fastapi_app/rag_advanced.py index 3541d8c7..10406351 100644 --- a/src/backend/fastapi_app/rag_advanced.py +++ b/src/backend/fastapi_app/rag_advanced.py @@ -3,7 +3,7 @@ from openai import AsyncAzureOpenAI, AsyncOpenAI from openai.types.chat import ChatCompletionMessageParam -from pydantic_ai import Agent, RunContext +from pydantic_ai import Agent from pydantic_ai.messages import ModelMessagesTypeAdapter from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.openai import OpenAIProvider @@ -11,15 +11,14 @@ from fastapi_app.api_models import ( AIChatRoles, - BrandFilter, ChatRequestOverrides, Filter, ItemPublic, Message, - PriceFilter, RAGContext, RetrievalResponse, RetrievalResponseDelta, + SearchArguments, SearchResults, ThoughtStep, ) @@ -59,7 +58,7 @@ def __init__( ), system_prompt=self.query_prompt_template, tools=[self.search_database], - output_type=SearchResults, + output_type=SearchArguments, ) self.answer_agent = Agent( pydantic_chat_model, @@ -73,10 +72,7 @@ def __init__( async def search_database( self, - ctx: RunContext[ChatParams], - search_query: str, - price_filter: Optional[PriceFilter] = None, - brand_filter: Optional[BrandFilter] = None, + search_arguments: SearchArguments, ) -> SearchResults: """ Search PostgreSQL database for relevant products based on user query @@ -91,52 +87,55 @@ async def search_database( """ # Only send non-None filters filters: list[Filter] = [] - if price_filter: - filters.append(price_filter) - if brand_filter: - filters.append(brand_filter) + if search_arguments.price_filter: + filters.append(search_arguments.price_filter) + if search_arguments.brand_filter: + filters.append(search_arguments.brand_filter) results = await self.searcher.search_and_embed( - search_query, - top=ctx.deps.top, - enable_vector_search=ctx.deps.enable_vector_search, - enable_text_search=ctx.deps.enable_text_search, + search_arguments.search_query, + top=self.chat_params.top, + enable_vector_search=self.chat_params.enable_vector_search, + enable_text_search=self.chat_params.enable_text_search, filters=filters, ) return SearchResults( - query=search_query, items=[ItemPublic.model_validate(item.to_dict()) for item in results], filters=filters + query=search_arguments.search_query, + items=[ItemPublic.model_validate(item.to_dict()) for item in results], + filters=filters, ) async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]: few_shots = ModelMessagesTypeAdapter.validate_json(self.query_fewshots) user_query = f"Find search results for user query: {self.chat_params.original_user_query}" - results = await self.search_agent.run( + search_agent_runner = await self.search_agent.run( user_query, message_history=few_shots + self.chat_params.past_messages, - deps=self.chat_params, + output_type=SearchArguments, ) - items = results.output.items + search_arguments = search_agent_runner.output + search_results = await self.search_database(search_arguments=search_arguments) thoughts = [ ThoughtStep( title="Prompt to generate search arguments", - description=results.all_messages(), + description=search_agent_runner.all_messages(), props=self.model_for_thoughts, ), ThoughtStep( title="Search using generated search arguments", - description=results.output.query, + description=search_results.query, props={ "top": self.chat_params.top, "vector_search": self.chat_params.enable_vector_search, "text_search": self.chat_params.enable_text_search, - "filters": results.output.filters, + "filters": search_results.filters, }, ), ThoughtStep( title="Search results", - description=items, + description=search_results.items, ), ] - return items, thoughts + return search_results.items, thoughts async def answer( self,