|
1 | 1 | import importlib
|
2 | 2 | import json
|
3 | 3 | import logging
|
4 |
| -import uuid |
5 | 4 | from collections.abc import AsyncGenerator
|
6 | 5 | from pathlib import Path
|
7 | 6 | from typing import Any, Literal
|
|
15 | 14 | from pydantic import BaseModel, Field
|
16 | 15 |
|
17 | 16 | from ragbits.chat.interface import ChatInterface
|
18 |
| -from ragbits.chat.interface.types import ChatResponse, Message |
| 17 | +from ragbits.chat.interface.types import ChatContext, ChatResponse, Message |
19 | 18 |
|
20 | 19 | logger = logging.getLogger(__name__)
|
21 | 20 |
|
@@ -112,33 +111,31 @@ async def chat_message(request: ChatMessageRequest) -> StreamingResponse:
|
112 | 111 | if not self.chat_interface:
|
113 | 112 | raise HTTPException(status_code=500, detail="Chat implementation is not initialized")
|
114 | 113 |
|
115 |
| - # Generate a unique message ID for this conversation message |
116 |
| - message_id = str(uuid.uuid4()) |
| 114 | + # Convert request context to ChatContext |
| 115 | + chat_context = ChatContext(**request.context) |
117 | 116 |
|
118 | 117 | # Verify state signature if provided
|
119 | 118 | if "state" in request.context and "signature" in request.context:
|
120 | 119 | state = request.context["state"]
|
121 | 120 | signature = request.context["signature"]
|
122 | 121 | if not ChatInterface.verify_state(state, signature):
|
123 |
| - logger.warning(f"Invalid state signature received for message {message_id}") |
| 122 | + logger.warning(f"Invalid state signature received for message {chat_context.message_id}") |
124 | 123 | raise HTTPException(
|
125 | 124 | status_code=status.HTTP_400_BAD_REQUEST,
|
126 | 125 | detail="Invalid state signature",
|
127 | 126 | )
|
128 |
| - # Remove the signature from context after verification |
129 |
| - del request.context["signature"] |
130 |
| - # Ensure context has a state field if not present |
131 |
| - elif "state" not in request.context: |
132 |
| - request.context["state"] = {} |
| 127 | + # Remove the signature from context after verification (it's already parsed into ChatContext) |
133 | 128 |
|
134 | 129 | # Get the response generator from the chat interface
|
135 | 130 | response_generator = self.chat_interface.chat(
|
136 |
| - message=request.message, history=[msg.model_dump() for msg in request.history], context=request.context |
| 131 | + message=request.message, |
| 132 | + history=[msg.model_dump() for msg in request.history], |
| 133 | + context=chat_context, |
137 | 134 | )
|
138 | 135 |
|
139 | 136 | # Pass the generator to the SSE formatter
|
140 | 137 | return StreamingResponse(
|
141 |
| - RagbitsAPI._chat_response_to_sse(response_generator, message_id, self.chat_interface), |
| 138 | + RagbitsAPI._chat_response_to_sse(response_generator), |
142 | 139 | media_type="text/event-stream",
|
143 | 140 | )
|
144 | 141 |
|
@@ -179,46 +176,23 @@ async def config() -> JSONResponse:
|
179 | 176 |
|
180 | 177 | @staticmethod
|
181 | 178 | async def _chat_response_to_sse(
|
182 |
| - responses: AsyncGenerator[ChatResponse], message_id: str, chat_interface: ChatInterface | None = None |
| 179 | + responses: AsyncGenerator[ChatResponse], |
183 | 180 | ) -> AsyncGenerator[str, None]:
|
184 | 181 | """
|
185 | 182 | Formats chat responses into Server-Sent Events (SSE) format for streaming to the client.
|
186 | 183 | Each response is converted to JSON and wrapped in the SSE 'data:' prefix.
|
187 | 184 |
|
188 | 185 | Args:
|
189 | 186 | responses: The chat response generator
|
190 |
| - message_id: The unique identifier for this message |
191 |
| - chat_interface: The chat interface instance to use for verifying state (optional) |
192 | 187 | """
|
193 |
| - # Send the message_id as the first SSE event |
194 |
| - data = json.dumps({"type": "message_id", "content": message_id}) |
195 |
| - yield f"data: {data}\n\n" |
196 |
| - |
197 | 188 | async for response in responses:
|
198 |
| - if response.type.value == "state_update": |
199 |
| - state_update = response.as_state_update() |
200 |
| - if state_update: |
201 |
| - # Verification is already done by the chat interface that created the state update |
202 |
| - data = json.dumps( |
203 |
| - { |
204 |
| - "type": "state_update", |
205 |
| - "content": { |
206 |
| - "state": state_update.state, |
207 |
| - "signature": state_update.signature, |
208 |
| - }, |
209 |
| - } |
210 |
| - ) |
211 |
| - yield f"data: {data}\n\n" |
212 |
| - else: |
213 |
| - data = json.dumps( |
214 |
| - { |
215 |
| - "type": response.type.value, |
216 |
| - "content": response.content |
217 |
| - if isinstance(response.content, str) |
218 |
| - else response.content.model_dump(), |
219 |
| - } |
220 |
| - ) |
221 |
| - yield f"data: {data}\n\n" |
| 189 | + data = json.dumps( |
| 190 | + { |
| 191 | + "type": response.type.value, |
| 192 | + "content": response.content if isinstance(response.content, str) else response.content.model_dump(), |
| 193 | + } |
| 194 | + ) |
| 195 | + yield f"data: {data}\n\n" |
222 | 196 |
|
223 | 197 | @staticmethod
|
224 | 198 | def _load_chat_interface(implementation: type[ChatInterface] | str) -> ChatInterface:
|
|
0 commit comments