13
13
from fastapi .templating import Jinja2Templates
14
14
15
15
from azure .ai .projects .aio import AIProjectClient
16
+ from fastapi .responses import JSONResponse
16
17
from azure .ai .projects .models import (
17
18
Agent ,
18
19
MessageDeltaChunk ,
19
20
ThreadMessage ,
20
21
ThreadRun ,
21
22
AsyncAgentEventHandler ,
23
+ OpenAIPageableListOfThreadMessage ,
24
+ MessageTextContent ,
25
+ MessageTextFileCitationAnnotation ,
26
+ MessageTextUrlCitationAnnotation ,
22
27
RunStep
23
28
)
24
29
@@ -47,7 +52,28 @@ def get_agent(request: Request) -> Agent:
47
52
def serialize_sse_event (data : Dict ) -> str :
48
53
return f"data: { json .dumps (data )} \n \n "
49
54
50
-
55
+ async def get_message_and_annotations (ai_client : AIProjectClient , message : ThreadMessage ) -> Dict :
56
+ annotations = []
57
+ # Get file annotations for the file search.
58
+ for annotation in (a .as_dict () for a in message .file_citation_annotations ):
59
+ file_id = annotation ["file_citation" ]["file_id" ]
60
+ logger .info (f"Fetching file with ID for annotation { file_id } " )
61
+ openai_file = await ai_client .agents .get_file (file_id )
62
+ annotation ["file_name" ] = openai_file .filename
63
+ logger .info (f"File name for annotation: { annotation ['file_name' ]} " )
64
+ annotations .append (annotation )
65
+
66
+ # Get url annotation for the index search.
67
+ for url_annotation in message .url_citation_annotations :
68
+ annotation = url_annotation .as_dict ()
69
+ annotation ["file_name" ] = annotation ['url_citation' ]['title' ]
70
+ logger .info (f"File name for annotation: { annotation ['file_name' ]} " )
71
+ annotations .append (annotation )
72
+
73
+ return {
74
+ 'content' : message .text_messages [0 ].text .value ,
75
+ 'annotations' : annotations
76
+ }
51
77
class MyEventHandler (AsyncAgentEventHandler [str ]):
52
78
def __init__ (self , ai_client : AIProjectClient ):
53
79
super ().__init__ ()
@@ -64,28 +90,9 @@ async def on_thread_message(self, message: ThreadMessage) -> Optional[str]:
64
90
return None
65
91
66
92
logger .info ("MyEventHandler: Received completed message" )
67
- annotations = []
68
- # Get file annotations for the file search.
69
- for annotation in (a .as_dict () for a in message .file_citation_annotations ):
70
- file_id = annotation ["file_citation" ]["file_id" ]
71
- logger .info (f"Fetching file with ID for annotation { file_id } " )
72
- openai_file = await self .ai_client .agents .get_file (file_id )
73
- annotation ["file_name" ] = openai_file .filename
74
- logger .info (f"File name for annotation: { annotation ['file_name' ]} " )
75
- annotations .append (annotation )
76
-
77
- # Get url annotation for the index search.
78
- for url_annotation in message .url_citation_annotations :
79
- annotation = url_annotation .as_dict ()
80
- annotation ["file_name" ] = annotation ['url_citation' ]['title' ]
81
- logger .info (f"File name for annotation: { annotation ['file_name' ]} " )
82
- annotations .append (annotation )
83
-
84
- stream_data = {
85
- 'content' : message .text_messages [0 ].text .value ,
86
- 'annotations' : annotations ,
87
- 'type' : "completed_message"
88
- }
93
+
94
+ stream_data = await get_message_and_annotations (self .ai_client , message )
95
+ stream_data ['type' ] = "completed_message"
89
96
return serialize_sse_event (stream_data )
90
97
except Exception as e :
91
98
logger .error (f"Error in event handler for thread message: { e } " , exc_info = True )
@@ -150,6 +157,52 @@ async def get_result(thread_id: str, agent_id: str, ai_client : AIProjectClient)
150
157
yield serialize_sse_event ({'type' : "error" , 'message' : str (e )})
151
158
152
159
160
+ @router .get ("/chat/history" )
161
+ async def history (
162
+ request : Request ,
163
+ ai_client : AIProjectClient = Depends (get_ai_client ),
164
+ agent : Agent = Depends (get_agent ),
165
+ ):
166
+ # Retrieve the thread ID from the cookies (if available).
167
+ thread_id = request .cookies .get ('thread_id' )
168
+ agent_id = request .cookies .get ('agent_id' )
169
+
170
+ # Attempt to get an existing thread. If not found, create a new one.
171
+ try :
172
+ if thread_id and agent_id == agent .id :
173
+ logger .info (f"Retrieving thread with ID { thread_id } " )
174
+ thread = await ai_client .agents .get_thread (thread_id )
175
+ else :
176
+ logger .info ("Creating a new thread" )
177
+ thread = await ai_client .agents .create_thread ()
178
+ except Exception as e :
179
+ logger .error (f"Error handling thread: { e } " )
180
+ raise HTTPException (status_code = 400 , detail = f"Error handling thread: { e } " )
181
+
182
+ thread_id = thread .id
183
+ messages = OpenAIPageableListOfThreadMessage ()
184
+
185
+ # Create a new message from the user's input.
186
+ try :
187
+ content = []
188
+ response = await ai_client .agents .list_messages (
189
+ thread_id = thread_id ,
190
+ )
191
+ for message in response .data :
192
+ content .append (await get_message_and_annotations (ai_client , message ))
193
+
194
+ logger .info (f"List message, thread ID: { thread_id } " )
195
+ response = JSONResponse (content = content )
196
+
197
+ # Update cookies to persist the thread and agent IDs.
198
+ response .set_cookie ("thread_id" , thread_id )
199
+ response .set_cookie ("agent_id" , agent_id )
200
+ return response
201
+ except Exception as e :
202
+ logger .error (f"Error listing message: { e } " )
203
+ raise HTTPException (status_code = 500 , detail = f"Error list message: { e } " )
204
+
205
+
153
206
@router .post ("/chat" )
154
207
async def chat (
155
208
request : Request ,
0 commit comments