diff --git a/router/src/lib.rs b/router/src/lib.rs index 414d38ed6e5..ded939589d8 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -790,12 +790,15 @@ impl ChatCompletionChunk { created, model, system_fingerprint, - choices: vec![ChatCompletionChoice { - index: 0, - delta, - logprobs, - finish_reason, - }], + choices: match usage { + None => vec![ChatCompletionChoice { + index: 0, + delta, + logprobs, + finish_reason, + }], + _ => vec![], + }, usage, } } diff --git a/router/src/server.rs b/router/src/server.rs index 9e57af27562..b6021282820 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1124,7 +1124,6 @@ enum StreamState { fn create_event_from_stream_token( stream_token: &StreamResponse, logprobs: bool, - stream_options: Option, inner_using_tools: bool, system_fingerprint: String, model_id: String, @@ -1153,8 +1152,43 @@ fn create_event_from_stream_token( }; let (usage, finish_reason) = match &stream_token.details { + Some(details) => (None, Some(details.finish_reason.format(true))), + None => (None, None), + }; + + let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + content, + tool_calls, + current_time, + logprobs, + finish_reason, + usage, + )); + + event.json_data(chat_complete).unwrap_or_else(|e| { + println!("Failed to serialize ChatCompletionChunk: {:?}", e); + Event::default() + }) +} + +/// Convert a StreamResponse into an Event to be sent over SSE +fn create_usage_event_from_stream_token( + stream_token: &StreamResponse, + stream_options: Option, + system_fingerprint: String, + model_id: String, +) -> Event { + let event = Event::default(); + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + + let usage = match &stream_token.details { Some(details) => { - let usage = if stream_options + if stream_options .as_ref() .map(|s| s.include_usage) .unwrap_or(false) @@ -1169,20 +1203,19 @@ fn create_event_from_stream_token( }) } else { None - }; - (usage, Some(details.finish_reason.format(true))) + } } - None => (None, None), + None => None, }; let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( model_id.clone(), system_fingerprint.clone(), - content, - tool_calls, + None, + None, current_time, - logprobs, - finish_reason, + None, + None, usage, )); @@ -1307,7 +1340,6 @@ pub(crate) async fn chat_completions( let event = create_event_from_stream_token( stream_token, logprobs, - stream_options.clone(), response_as_tool, system_fingerprint.clone(), model_id.clone(), @@ -1369,13 +1401,25 @@ pub(crate) async fn chat_completions( let event = create_event_from_stream_token( &stream_token, logprobs, - stream_options.clone(), response_as_tool, system_fingerprint.clone(), model_id.clone(), ); yield Ok::(event); + + if stream_token.details.is_some() && stream_options + .as_ref() + .map(|s| s.include_usage) + .unwrap_or(false) { + let usage_event = create_usage_event_from_stream_token( + &stream_token, + stream_options.clone(), + system_fingerprint.clone(), + model_id.clone(), + ); + yield Ok::(usage_event); + } } } }