Skip to content

Commit 6c5d7b8

Browse files
authored
fix: Resolve WebSocket issues and add real-time call example (#160)
- Add support for chat update events in WebSocket communication - Introduce new real-time audio chat GUI example application
1 parent da63a6d commit 6c5d7b8

File tree

7 files changed

+472
-12
lines changed

7 files changed

+472
-12
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@ dist/
99
scripts/
1010
.cache/
1111
output.wav
12+
response.wav
13+
temp_response.pcm

cozepy/chat/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def build_assistant_answer(content: str, meta_data: Optional[Dict[str, str]] = N
191191
def get_audio(self) -> Optional[bytes]:
192192
if self.content_type == MessageContentType.AUDIO:
193193
return base64.b64decode(self.content)
194-
return None
194+
return b""
195195

196196

197197
class ChatStatus(str, Enum):

cozepy/websockets/audio/speech/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def speech_update(self, event: SpeechUpdateEvent) -> None:
121121
self._input_queue.put(event)
122122

123123
def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
124-
event_id = message.get("event_id") or ""
124+
event_id = message.get("id") or ""
125125
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
126126
event_type = message.get("event_type") or ""
127127
data = message.get("data") or {}
@@ -235,7 +235,7 @@ async def speech_update(self, data: SpeechUpdateEvent.Data) -> None:
235235
await self._input_queue.put(SpeechUpdateEvent.model_validate({"data": data}))
236236

237237
def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
238-
event_id = message.get("event_id") or ""
238+
event_id = message.get("id") or ""
239239
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
240240
event_type = message.get("event_type") or ""
241241
data = message.get("data") or {}

cozepy/websockets/audio/transcriptions/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ def serialize_delta(self, delta: bytes, _info):
3030
event_type: WebsocketsEventType = WebsocketsEventType.INPUT_AUDIO_BUFFER_APPEND
3131
data: Data
3232

33+
def _dump_without_delta(self):
34+
return {
35+
"id": self.id,
36+
"type": self.event_type.value,
37+
"detail": self.detail,
38+
"data": {
39+
"delta_length": len(self.data.delta) if self.data and self.data.delta else 0,
40+
},
41+
}
42+
3343

3444
# req
3545
class InputAudioBufferCompleteEvent(WebsocketsEvent):
@@ -127,7 +137,7 @@ def input_audio_buffer_complete(self) -> None:
127137
self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({}))
128138

129139
def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
130-
event_id = message.get("event_id") or ""
140+
event_id = message.get("id") or ""
131141
event_type = message.get("event_type") or ""
132142
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
133143
data = message.get("data") or {}
@@ -250,7 +260,7 @@ async def input_audio_buffer_complete(self) -> None:
250260
await self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({}))
251261

252262
def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
253-
event_id = message.get("event_id") or ""
263+
event_id = message.get("id") or ""
254264
event_type = message.get("event_type") or ""
255265
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
256266
data = message.get("data") or {}

cozepy/websockets/chat/__init__.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ class ChatCreatedEvent(WebsocketsEvent):
5858
event_type: WebsocketsEventType = WebsocketsEventType.CHAT_CREATED
5959

6060

61+
# resp
62+
class ChatUpdatedEvent(WebsocketsEvent):
63+
event_type: WebsocketsEventType = WebsocketsEventType.CHAT_UPDATED
64+
data: ChatUpdateEvent.Data
65+
66+
6167
# resp
6268
class ConversationChatCreatedEvent(WebsocketsEvent):
6369
event_type: WebsocketsEventType = WebsocketsEventType.CONVERSATION_CHAT_CREATED
@@ -107,6 +113,9 @@ class WebsocketsChatEventHandler(WebsocketsBaseEventHandler):
107113
def on_chat_created(self, cli: "WebsocketsChatClient", event: ChatCreatedEvent):
108114
pass
109115

116+
def on_chat_updated(self, cli: "WebsocketsChatClient", event: ChatUpdatedEvent):
117+
pass
118+
110119
def on_input_audio_buffer_completed(self, cli: "WebsocketsChatClient", event: InputAudioBufferCompletedEvent):
111120
pass
112121

@@ -151,6 +160,7 @@ def __init__(
151160
on_event = on_event.to_dict(
152161
{
153162
WebsocketsEventType.CHAT_CREATED: on_event.on_chat_created,
163+
WebsocketsEventType.CHAT_UPDATED: on_event.on_chat_updated,
154164
WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed,
155165
WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created,
156166
WebsocketsEventType.CONVERSATION_CHAT_IN_PROGRESS: on_event.on_conversation_chat_in_progress,
@@ -188,7 +198,7 @@ def input_audio_buffer_complete(self) -> None:
188198
self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({}))
189199

190200
def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
191-
event_id = message.get("event_id") or ""
201+
event_id = message.get("id") or ""
192202
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
193203
event_type = message.get("event_type") or ""
194204
data = message.get("data") or {}
@@ -199,6 +209,14 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
199209
"detail": detail,
200210
}
201211
)
212+
elif event_type == WebsocketsEventType.CHAT_UPDATED.value:
213+
return ChatUpdatedEvent.model_validate(
214+
{
215+
"id": event_id,
216+
"detail": detail,
217+
"data": ChatUpdateEvent.Data.model_validate(data),
218+
}
219+
)
202220
elif event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value:
203221
return InputAudioBufferCompletedEvent.model_validate(
204222
{
@@ -299,6 +317,9 @@ class AsyncWebsocketsChatEventHandler(AsyncWebsocketsBaseEventHandler):
299317
async def on_chat_created(self, cli: "AsyncWebsocketsChatClient", event: ChatCreatedEvent):
300318
pass
301319

320+
async def on_chat_updated(self, cli: "AsyncWebsocketsChatClient", event: ChatUpdatedEvent):
321+
pass
322+
302323
async def on_input_audio_buffer_completed(
303324
self, cli: "AsyncWebsocketsChatClient", event: InputAudioBufferCompletedEvent
304325
):
@@ -355,6 +376,7 @@ def __init__(
355376
on_event = on_event.to_dict(
356377
{
357378
WebsocketsEventType.CHAT_CREATED: on_event.on_chat_created,
379+
WebsocketsEventType.CHAT_UPDATED: on_event.on_chat_updated,
358380
WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED: on_event.on_input_audio_buffer_completed,
359381
WebsocketsEventType.CONVERSATION_CHAT_CREATED: on_event.on_conversation_chat_created,
360382
WebsocketsEventType.CONVERSATION_CHAT_IN_PROGRESS: on_event.on_conversation_chat_in_progress,
@@ -392,7 +414,7 @@ async def input_audio_buffer_complete(self) -> None:
392414
await self._input_queue.put(InputAudioBufferCompleteEvent.model_validate({}))
393415

394416
def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
395-
event_id = message.get("event_id") or ""
417+
event_id = message.get("id") or ""
396418
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
397419
event_type = message.get("event_type") or ""
398420
data = message.get("data") or {}
@@ -403,6 +425,14 @@ def _load_event(self, message: Dict) -> Optional[WebsocketsEvent]:
403425
"detail": detail,
404426
}
405427
)
428+
elif event_type == WebsocketsEventType.CHAT_UPDATED.value:
429+
return ChatUpdatedEvent.model_validate(
430+
{
431+
"id": event_id,
432+
"detail": detail,
433+
"data": ChatUpdateEvent.Data.model_validate(data),
434+
}
435+
)
406436
elif event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_COMPLETED.value:
407437
return InputAudioBufferCompletedEvent.model_validate(
408438
{

cozepy/websockets/ws.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class WebsocketsEventType(str, Enum):
9393
CONVERSATION_CHAT_SUBMIT_TOOL_OUTPUTS = "conversation.chat.submit_tool_outputs" # send tool outputs to server
9494
# resp
9595
CHAT_CREATED = "chat.created"
96+
CHAT_UPDATED = "chat.updated"
9697
# INPUT_AUDIO_BUFFER_COMPLETED = "input_audio_buffer.completed" # received `input_audio_buffer.complete` event
9798
CONVERSATION_CHAT_CREATED = "conversation.chat.created" # audio ast completed, chat started
9899
CONVERSATION_CHAT_IN_PROGRESS = "conversation.chat.in_progress"
@@ -109,7 +110,7 @@ class Detail(BaseModel):
109110
logid: Optional[str] = None
110111

111112
event_type: WebsocketsEventType
112-
event_id: Optional[str] = None
113+
id: Optional[str] = None
113114
detail: Optional[Detail] = None
114115

115116

@@ -118,7 +119,7 @@ class WebsocketsErrorEvent(WebsocketsEvent):
118119
data: CozeAPIError
119120

120121

121-
class InputAudio(CozeModel):
122+
class InputAudio(BaseModel):
122123
format: Optional[str]
123124
codec: Optional[str]
124125
sample_rate: Optional[int]
@@ -266,7 +267,7 @@ def _receive_loop(self) -> None:
266267
self._handle_error(e)
267268

268269
def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]:
269-
event_id = message.get("event_id") or ""
270+
event_id = message.get("id") or ""
270271
event_type = message.get("event_type") or ""
271272
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
272273
data = message.get("data") or {}
@@ -466,7 +467,7 @@ async def _receive_loop(self) -> None:
466467
await self._handle_error(e)
467468

468469
def _load_all_event(self, message: Dict) -> Optional[WebsocketsEvent]:
469-
event_id = message.get("event_id") or ""
470+
event_id = message.get("id") or ""
470471
event_type = message.get("event_type") or ""
471472
detail = WebsocketsEvent.Detail.model_validate(message.get("detail") or {})
472473
data = message.get("data") or {}
@@ -553,7 +554,15 @@ async def _close(self) -> None:
553554
async def _send_event(self, event: Optional[WebsocketsEvent] = None) -> None:
554555
if not event or not self._ws:
555556
return
556-
log_debug("[%s] send event, type=%s", self._path, event.event_type.value)
557+
if event.event_type == WebsocketsEventType.INPUT_AUDIO_BUFFER_APPEND:
558+
log_debug(
559+
"[%s] send event, type=%s, event=%s",
560+
self._path,
561+
event.event_type.value,
562+
json.dumps(event._dump_without_delta()), # type: ignore
563+
)
564+
else:
565+
log_debug("[%s] send event, type=%s, event=%s", self._path, event.event_type.value, event.model_dump_json())
557566
await self._ws.send(event.model_dump_json())
558567

559568

0 commit comments

Comments
 (0)