@@ -1,8 +1,13 @@
+import json
+import logging
+from collections.abc import AsyncGenerator
+
 import fastapi
 from fastapi import HTTPException
+from fastapi.responses import StreamingResponse
 from sqlalchemy import select
 
-from fastapi_app.api_models import ChatRequest, ItemPublic, ItemWithDistance, RetrievalResponse
+from fastapi_app.api_models import ChatRequest, ItemPublic, ItemWithDistance, Message, RetrievalResponse
 from fastapi_app.dependencies import ChatClient, CommonDeps, DBSession, EmbeddingsClient
 from fastapi_app.postgres_models import Item
 from fastapi_app.postgres_searcher import PostgresSearcher
@@ -12,6 +17,18 @@
 router = fastapi.APIRouter()
 
 
+async def format_as_ndjson(r: AsyncGenerator[RetrievalResponse | Message, None]) -> AsyncGenerator[str, None]:
+    """
+    Format the response as NDJSON
+    """
+    try:
+        async for event in r:
+            yield json.dumps(event.model_dump(), ensure_ascii=False) + "\n"
+    except Exception as error:
+        logging.exception("Exception while generating response stream: %s", error)
+        yield json.dumps({"error": str(error)}, ensure_ascii=False) + "\n"
+
+
 @router.get("/items/{id}", response_model=ItemPublic)
 async def item_handler(database_session: DBSession, id: int) -> ItemPublic:
     """A simple API to get an item by ID."""
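
The new `format_as_ndjson` helper emits one JSON document per line for every pydantic event the RAG flow yields. A minimal standalone sketch of that behavior follows; `DemoEvent` is a hypothetical stand-in for the `RetrievalResponse`/`Message` models and is not part of this PR:

```python
import asyncio
import json
from collections.abc import AsyncGenerator

from pydantic import BaseModel


class DemoEvent(BaseModel):
    """Stand-in for the RetrievalResponse/Message models used by the router."""

    content: str


async def demo_events() -> AsyncGenerator[DemoEvent, None]:
    yield DemoEvent(content="first chunk")
    yield DemoEvent(content="second chunk")


async def to_ndjson(events: AsyncGenerator[DemoEvent, None]) -> AsyncGenerator[str, None]:
    # Same idea as format_as_ndjson: serialize each event as one JSON line.
    async for event in events:
        yield json.dumps(event.model_dump(), ensure_ascii=False) + "\n"


async def main() -> None:
    # Prints:
    # {"content": "first chunk"}
    # {"content": "second chunk"}
    async for line in to_ndjson(demo_events()):
        print(line, end="")


asyncio.run(main())
```
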
@@ -96,3 +113,39 @@ async def chat_handler(
 
     response = await run_ragchat(chat_request.messages, overrides=overrides)
     return response
+
+
+@router.post("/chat/stream")
+async def chat_stream_handler(
+    context: CommonDeps,
+    database_session: DBSession,
+    openai_embed: EmbeddingsClient,
+    openai_chat: ChatClient,
+    chat_request: ChatRequest,
+):
+    overrides = chat_request.context.get("overrides", {})
+
+    searcher = PostgresSearcher(
+        db_session=database_session,
+        openai_embed_client=openai_embed.client,
+        embed_deployment=context.openai_embed_deployment,
+        embed_model=context.openai_embed_model,
+        embed_dimensions=context.openai_embed_dimensions,
+    )
+    if overrides.get("use_advanced_flow"):
+        run_ragchat = AdvancedRAGChat(
+            searcher=searcher,
+            openai_chat_client=openai_chat.client,
+            chat_model=context.openai_chat_model,
+            chat_deployment=context.openai_chat_deployment,
+        ).run_stream
+    else:
+        run_ragchat = SimpleRAGChat(
+            searcher=searcher,
+            openai_chat_client=openai_chat.client,
+            chat_model=context.openai_chat_model,
+            chat_deployment=context.openai_chat_deployment,
+        ).run_stream
+
+    result = run_ragchat(chat_request.messages, overrides=overrides)
+    return StreamingResponse(content=format_as_ndjson(result), media_type="application/x-ndjson")
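
For reference, a rough client-side sketch of consuming the new `/chat/stream` endpoint. The request body mirrors `ChatRequest` as used above (a `messages` list plus `context.overrides`); the exact message fields and the base URL are assumptions, not taken from this PR:

```python
import json

import httpx

# Hypothetical base URL; adjust to wherever the FastAPI app is served.
with httpx.stream(
    "POST",
    "http://localhost:8000/chat/stream",
    json={
        "messages": [{"role": "user", "content": "Which tents are waterproof?"}],
        "context": {"overrides": {"use_advanced_flow": True}},
    },
    timeout=None,
) as response:
    # Each non-empty line is one NDJSON event produced by format_as_ndjson.
    for line in response.iter_lines():
        if line:
            print(json.loads(line))
```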