Skip to content

Commit bbd4bad

Browse files
committed
Implement /chat/stream endpoint and fix empty choices error
1 parent 7e11373 commit bbd4bad

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

src/backend/fastapi_app/rag_simple.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,5 +232,7 @@ async def run_stream(
232232
),
233233
)
234234
async for response_chunk in chat_completion_async_stream:
235-
yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant")
235+
# first response has empty choices
236+
if response_chunk.choices:
237+
yield Message(content=str(response_chunk.choices[0].delta.content), role="assistant")
236238
return

src/backend/fastapi_app/routes/api_routes.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
import json
2+
import logging
3+
from collections.abc import AsyncGenerator
4+
15
import fastapi
26
from fastapi import HTTPException
7+
from fastapi.responses import StreamingResponse
38
from sqlalchemy import select
49

5-
from fastapi_app.api_models import ChatRequest, ItemPublic, ItemWithDistance, RetrievalResponse
10+
from fastapi_app.api_models import ChatRequest, ItemPublic, ItemWithDistance, Message, RetrievalResponse
611
from fastapi_app.dependencies import ChatClient, CommonDeps, DBSession, EmbeddingsClient
712
from fastapi_app.postgres_models import Item
813
from fastapi_app.postgres_searcher import PostgresSearcher
@@ -12,6 +17,18 @@
1217
# Router that the route handlers in this module register on via @router.get / @router.post.
router = fastapi.APIRouter()
1318

1419

20+
async def format_as_ndjson(r: AsyncGenerator[RetrievalResponse | Message, None]) -> AsyncGenerator[str, None]:
    """
    Serialize a stream of response models into NDJSON lines.

    Every item produced by ``r`` is converted with ``model_dump()`` and emitted
    as a single JSON document followed by a newline. Non-ASCII characters are
    written as-is rather than escaped.

    If an exception is raised while the stream is being consumed or serialized,
    it is logged and one final ``{"error": "<message>"}`` line is emitted
    instead of letting the exception propagate to the caller.
    """
    try:
        async for chunk in r:
            payload = chunk.model_dump()
            line = json.dumps(payload, ensure_ascii=False)
            yield line + "\n"
    except Exception as error:
        # Report the failure in-band as a last NDJSON line rather than raising.
        logging.exception("Exception while generating response stream: %s", error)
        error_line = json.dumps({"error": str(error)}, ensure_ascii=False)
        yield error_line + "\n"
30+
31+
1532
@router.get("/items/{id}", response_model=ItemPublic)
1633
async def item_handler(database_session: DBSession, id: int) -> ItemPublic:
1734
"""A simple API to get an item by ID."""
@@ -96,3 +113,39 @@ async def chat_handler(
96113

97114
response = await run_ragchat(chat_request.messages, overrides=overrides)
98115
return response
116+
117+
118+
@router.post("/chat/stream")
async def chat_stream_handler(
    context: CommonDeps,
    database_session: DBSession,
    openai_embed: EmbeddingsClient,
    openai_chat: ChatClient,
    chat_request: ChatRequest,
):
    # Streaming chat endpoint: builds a searcher and a RAG chat flow from the
    # injected dependencies, then returns the flow's run_stream() output as an
    # NDJSON streaming response.
    # (Kept as comments, not a docstring, so the generated OpenAPI description
    # of this route is unchanged.)
    overrides = chat_request.context.get("overrides", {})

    searcher = PostgresSearcher(
        db_session=database_session,
        openai_embed_client=openai_embed.client,
        embed_deployment=context.openai_embed_deployment,
        embed_model=context.openai_embed_model,
        embed_dimensions=context.openai_embed_dimensions,
    )

    # Both flows take the same constructor arguments, so choose the class first
    # and build it once; "use_advanced_flow" in the overrides selects the advanced one.
    rag_flow_cls = AdvancedRAGChat if overrides.get("use_advanced_flow") else SimpleRAGChat
    rag_flow = rag_flow_cls(
        searcher=searcher,
        openai_chat_client=openai_chat.client,
        chat_model=context.openai_chat_model,
        chat_deployment=context.openai_chat_deployment,
    )

    event_stream = rag_flow.run_stream(chat_request.messages, overrides=overrides)
    return StreamingResponse(content=format_as_ndjson(event_stream), media_type="application/x-ndjson")

0 commit comments

Comments
 (0)