From efb20054aaca1727e04364152a04a23a91f89b63 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 27 Feb 2025 16:12:51 +0000 Subject: [PATCH] feat: consolidate streaming and event creation logic and add tests for streaming generations --- router/src/lib.rs | 557 ++++++++++++++++++++++++++++++++++++++----- router/src/server.rs | 315 ++---------------------- 2 files changed, 516 insertions(+), 356 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index dfb68b490e6..f23ef8addad 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -754,59 +754,6 @@ pub(crate) struct Function { pub arguments: String, } -#[allow(clippy::too_many_arguments)] -impl ChatCompletionChunk { - pub(crate) fn new( - model: String, - system_fingerprint: String, - delta: Option, - tool_calls: Option>, - created: u64, - logprobs: Option, - finish_reason: Option, - usage: Option, - tool_name: Option, - ) -> Self { - let delta = match (delta, tool_calls) { - (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage { - role: "assistant".to_string(), - content: delta, - ..Default::default() - }), - (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta { - role: "assistant".to_string(), - tool_calls: vec![DeltaToolCall { - index: 0, - id: String::new(), - r#type: "function".to_string(), - function: Function { - name: tool_name, - arguments: tool_calls[0].to_string(), - }, - }], - }), - (None, None) => ChatCompletionDelta::Chat(TextMessage { - role: "assistant".to_string(), - content: "".to_string(), - ..Default::default() - }), - }; - Self { - id: String::new(), - created, - model, - system_fingerprint, - choices: vec![ChatCompletionChoice { - index: 0, - delta, - logprobs, - finish_reason, - }], - usage, - } - } -} - #[derive(Clone, Deserialize, ToSchema, Serialize)] #[cfg_attr(test, derive(Debug, PartialEq, Default))] pub(crate) struct ChatRequest { @@ -1021,7 +968,7 @@ impl ChatRequest { #[derive(Clone, Deserialize, ToSchema, Serialize)] #[cfg_attr(test, derive(Debug, PartialEq))] -struct StreamOptions { +pub(crate) struct StreamOptions { /// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value. #[schema(example = "true")] include_usage: bool, @@ -1844,9 +1791,306 @@ mod tests { } } +fn create_event( + token_text: &str, + model_id: &str, + system_fingerprint: &str, + tool_name: Option<&str>, + is_tool_arg: bool, + finish_reason: Option, +) -> Event { + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + // Create the delta based on direct pattern matching of parameters + let (delta, finish) = match (tool_name, is_tool_arg, finish_reason) { + (Some(name), _, _) => ( + // Tool call name event + ChatCompletionDelta::Tool(ToolCallDelta { + role: "assistant".to_string(), + tool_calls: vec![DeltaToolCall { + index: 0, + id: String::new(), + r#type: "function".to_string(), + function: Function { + name: Some(name.to_string()), + arguments: String::new(), + }, + }], + }), + None, + ), + (None, true, _) => ( + // Tool call argument event + ChatCompletionDelta::Tool(ToolCallDelta { + role: "assistant".to_string(), + tool_calls: vec![DeltaToolCall { + index: 0, + id: String::new(), + r#type: "function".to_string(), + function: Function { + name: None, + arguments: token_text.to_string(), + }, + }], + }), + None, + ), + (None, false, reason) => ( + // Regular text event + ChatCompletionDelta::Chat(TextMessage { + role: "assistant".to_string(), + content: token_text.to_string(), + ..Default::default() + }), + reason, + ), + }; + + // Create the ChatCompletionChunk with the appropriate delta + let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk { + id: String::new(), + created: current_time, + model: model_id.to_string(), + system_fingerprint: system_fingerprint.to_string(), + choices: vec![ChatCompletionChoice { + index: 0, + delta, + logprobs: None, + finish_reason: finish, + }], + usage: None, + }); + + Event::default() + .json_data(chat_complete) + .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()) +} + +#[derive(serde::Serialize, serde::Deserialize, Debug)] +pub struct ParseFunction { + #[serde(rename = "_name")] + name: String, +} + +#[derive(serde::Serialize, serde::Deserialize, Debug)] +pub struct ToolDecision { + #[serde(rename = "function")] + function: ParseFunction, +} + +#[derive(serde::Serialize, serde::Deserialize, Debug)] +pub struct NoToolDecision { + content: String, +} + +use axum::response::sse::Event; +use serde_json::Value; +use std::convert::Infallible; + +fn get_timestamp() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +// Process stream token into events and mutates the buffers to keep track of the current state +#[allow(clippy::too_many_arguments)] +async fn process_stream_token( + token_text: String, + json_buffer: &mut String, + name_found: &mut bool, + no_tool_chosen: &mut bool, + first_quote_removed: &mut bool, + using_tools: bool, + model_id: &str, + system_fingerprint: &str, + stream_token: &StreamResponse, + stream_options: Option<&StreamOptions>, +) -> (Vec>, bool) { + let mut events = Vec::new(); + let mut should_break = false; + + // Get usage information + let usage = stream_token.details.as_ref().map(|d| Usage { + completion_tokens: d.generated_tokens, + prompt_tokens: d.input_length, + total_tokens: d.input_length + d.generated_tokens, + }); + + json_buffer.push_str(&token_text); + + // Phase 1: Function name discovery + if !*name_found { + // NOTE: when tools are supplied `name_found` is false until the generated buffer contains + // a partial JSON object with $.function._name value. This name determines the type + // of events to emit. If the name is "no_tool", we'll emit the "content" field as a chat + // completion event. Otherwise, we'll emit a tool call name event followed by a tool call + // argument event. In both cases we'll buffer tokens to get the name and then reset the buffer + // to collect the arguments. + if let Ok(ParseResult { + value: ToolDecision { + function: ParseFunction { name }, + }, + last_value_whole, + }) = parse_partial_json(json_buffer) + { + if !last_value_whole { + return (events, should_break); + } + *name_found = true; + if name == "no_tool" { + *no_tool_chosen = true; + } else { + events.push(Ok(create_event( + &token_text, + model_id, + system_fingerprint, + Some(name.as_str()), + false, + None, + ))); + events.push(Ok(create_event( + "{", + model_id, + system_fingerprint, + None, + true, + None, + ))); + } + + // Reset buffer for arguments + json_buffer.clear(); + json_buffer.push('{'); + } + + return (events, should_break); + } + + // Phase 2: Content processing + let is_complete_json = json_buffer.ends_with('}') + && serde_json::from_str::(&json_buffer[..json_buffer.len() - 1]).is_ok(); + let mut edited_token = token_text; + + // Handle different flows based on context + if using_tools { + if *no_tool_chosen && !is_complete_json { + // Content-only flow + if let Ok(ParseResult { + value: _, + last_value_whole, + }) = parse_partial_json::(json_buffer) + { + let cleaned_token = if !*first_quote_removed { + // trim start until the first quote + *first_quote_removed = true; + edited_token + .trim_start() + .strip_prefix('"') + .unwrap_or(&edited_token) + .to_string() + } else if last_value_whole { + should_break = true; + // trim end until the last quote + edited_token + .trim_end() + .strip_suffix('"') + .unwrap_or(&edited_token) + .to_string() + } else { + edited_token.to_string() + }; + + if !cleaned_token.is_empty() { + events.push(Ok(create_event( + &cleaned_token, + model_id, + system_fingerprint, + None, + false, + None, + ))); + } + } + } else { + // Tool with arguments flow + if is_complete_json { + edited_token.truncate(edited_token.len() - 1); + should_break = true; + } + events.push(Ok(create_event( + &edited_token, + model_id, + system_fingerprint, + None, + true, + None, + ))); + } + } else { + // Standard chat completion flow + if let Some(details) = stream_token.details.as_ref() { + let finish_reason = details.finish_reason.format(true); + let text = if details.finish_reason == FinishReason::Length { + &edited_token + } else { + "" + }; + events.push(Ok(create_event( + text, + model_id, + system_fingerprint, + None, + false, + Some(finish_reason), + ))); + should_break = true; + } else { + events.push(Ok(create_event( + &edited_token, + model_id, + system_fingerprint, + None, + false, + None, + ))); + } + } + + // Emit usage data when requested + if let (Some(usage_data), true) = ( + usage, + stream_options.as_ref().is_some_and(|o| o.include_usage), + ) { + let current_time = get_timestamp(); + + let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk { + id: String::new(), + created: current_time, + model: model_id.to_string(), + system_fingerprint: system_fingerprint.to_string(), + choices: vec![], + usage: Some(usage_data), + }); + + events.push(Ok(Event::default() + .json_data(chat_complete) + .unwrap_or_else(|e| { + InferError::StreamSerializationError(e.to_string()).into() + }))); + } + + (events, should_break) +} + #[cfg(test)] mod tool_streaming_tests { use super::*; + use futures::StreamExt; // Test json balancing and completion #[test] @@ -1971,7 +2215,7 @@ mod tool_streaming_tests { // Function decision let result = parse_partial_json::(&json_buffers[0]); - println!("{:?}", result); + assert!(result.is_ok()); assert_eq!(result.unwrap().value.function.name, "no_tool"); @@ -1990,7 +2234,7 @@ mod tool_streaming_tests { // Function decision let result = parse_partial_json::(&json_buffers[0]); - println!("{:?}", result); + assert!(result.is_ok()); assert_eq!(result.unwrap().value.function.name, "no_tool"); @@ -2053,4 +2297,205 @@ mod tool_streaming_tests { assert_eq!(parsed.value.as_object().unwrap()["format"], "fahrenheit"); assert_eq!(parsed.last_value_whole, true); } + + #[tokio::test] + async fn test_streaming_no_tool_decision() { + let tokens_to_stream = vec![ + "{\"".to_string(), + "function".to_string(), + "\":".to_string(), + " {\"".to_string(), + "_".to_string(), + "name".to_string(), + "\":".to_string(), + " \"".to_string(), + "no".to_string(), + "_tool".to_string(), + "\",".to_string(), + " \"".to_string(), + "content".to_string(), + "\":".to_string(), + " \"".to_string(), + "I".to_string(), // Event 1 + " am".to_string(), // Event 2 + " a".to_string(), // Event 3 + " helpful".to_string(), // Event 4 + " assistant".to_string(), // Event 5 + "!\"".to_string(), // Event 6 (with trailing quore removed) + ]; + + let event_to_stream = tokens_to_stream + .into_iter() + .enumerate() + .map(|(i, token)| StreamResponse { + index: i as u32, + token: Token { + id: 0, + text: token, + logprob: 0.0, + special: false, + }, + top_tokens: vec![], + generated_text: None, + details: None, + }) + .collect::>(); + + // Create a stream from our test events + let stream = futures::stream::iter( + event_to_stream + .into_iter() + .map(Ok::), + ); + + // Initialize variables + let mut json_buffer = String::new(); + let mut name_found = false; + let mut no_tool_chosen = false; + let mut first_quote_removed = false; + let mut events = Vec::new(); + + let using_tools = true; + let model_id = "gpt2"; + let system_fingerprint = "test"; + + let stream_options = Some(StreamOptions { + include_usage: true, + }); + + // Use StreamExt to get access to next() method + use futures::StreamExt; + let mut stream = Box::pin(stream); + + // Process the stream asynchronously + while let Some(Ok(stream_token)) = stream.next().await { + let (new_events, should_break) = process_stream_token( + stream_token.token.text.clone(), + &mut json_buffer, + &mut name_found, + &mut no_tool_chosen, + &mut first_quote_removed, + using_tools, + model_id, + system_fingerprint, + &stream_token, + stream_options.as_ref(), + ) + .await; + + events.extend(new_events); + if should_break { + break; + } + } + // Expect 6 events (the relevant tokens within content) + assert_eq!(events.len(), 6); + // "I am a helpful assistant!" + } + + #[tokio::test] + async fn test_streaming_tool_decision() { + let tokens_to_stream = vec![ + "{\"".to_string(), + "function".to_string(), + "\":".to_string(), + " {\"".to_string(), + "_".to_string(), + "name".to_string(), + "\":".to_string(), + " \"".to_string(), + "get".to_string(), + "_current".to_string(), + "_weather".to_string(), + "\",".to_string(), + // Event 1 is the function name + // Event 2 is the start of the arguments "{" + " \"".to_string(), // Event 3 + "location".to_string(), // Event 4 + "\":".to_string(), // Event 5 + " \"".to_string(), // Event 6 + "San".to_string(), // Event 7 + " Francisco".to_string(), // Event 8 + ",".to_string(), // Event 9 + " CA".to_string(), // Event 10 + "\",".to_string(), // Event 11 + " \"".to_string(), // Event 12 + "format".to_string(), // Event 13 + "\":".to_string(), // Event 14 + " \"".to_string(), // Event 15 + "c".to_string(), // Event 16 + "elsius".to_string(), // Event 17 + "\"}}".to_string(), // Event 18 retained (trailing brace removed) + ]; + + let event_to_stream = tokens_to_stream + .into_iter() + .enumerate() + .map(|(i, token)| StreamResponse { + index: i as u32, + token: Token { + id: 0, + text: token, + logprob: 0.0, + special: false, + }, + top_tokens: vec![], + generated_text: None, + details: None, + }) + .collect::>(); + + // Create a stream from our test events + let stream = futures::stream::iter( + event_to_stream + .into_iter() + .map(Ok::), + ); + + // Initialize variables + let mut json_buffer = String::new(); + let mut name_found = false; + let mut no_tool_chosen = false; + let mut first_quote_removed = false; + let mut events = Vec::new(); + + let using_tools = true; + let model_id = "gpt2"; + let system_fingerprint = "test"; + + let stream_options = Some(StreamOptions { + include_usage: true, + }); + + let mut stream = Box::pin(stream); + + // Process the stream asynchronously + while let Some(Ok(stream_token)) = stream.next().await { + let (new_events, should_break) = process_stream_token( + stream_token.token.text.clone(), + &mut json_buffer, + &mut name_found, + &mut no_tool_chosen, + &mut first_quote_removed, + using_tools, + model_id, + system_fingerprint, + &stream_token, + stream_options.as_ref(), + ) + .await; + + events.extend(new_events); + if should_break { + break; + } + } + + for event in &events { + println!("{:?}", event); + } + + assert_eq!(events.len(), 18); + // "{ "location": "San Francisco, CA", "format": "celsius"}" + } } diff --git a/router/src/server.rs b/router/src/server.rs index de0cf53d260..beb73dea273 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -6,6 +6,7 @@ use crate::kserve::{ kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kserve_model_metadata, kserve_model_metadata_ready, }; +use crate::process_stream_token; use crate::sagemaker::{ sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse, __path_sagemaker_compatibility, @@ -13,7 +14,6 @@ use crate::sagemaker::{ use crate::validation::ValidationError; use crate::vertex::vertex_compatibility; use crate::ChatTokenizeResponse; -use crate::{parse_partial_json, ParseResult}; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -1114,116 +1114,6 @@ pub(crate) async fn completions( } } -/// Creates an event based on the token text and event type parameters. -/// `token_text` - The text to include (extract from StreamResponse.token.text or str) -/// `model_id` - Model identifier string -/// `system_fingerprint` - System fingerprint string -/// `tool_name` - If provided, creates a tool call name event -/// `is_tool_arg` - If true, creates a tool call argument event -fn create_event( - token_text: &str, - model_id: &str, - system_fingerprint: &str, - tool_name: Option<&str>, - is_tool_arg: bool, - finish_reason: Option, -) -> Event { - let current_time = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); - - let chat_complete = if let Some(tool_name) = tool_name { - // Tool call name event - let tool_delta = ChatCompletionDelta::Tool(ToolCallDelta { - role: "assistant".to_string(), - tool_calls: vec![DeltaToolCall { - index: 0, - id: String::new(), - r#type: "function".to_string(), - function: Function { - name: Some(tool_name.to_string()), - arguments: "".to_string(), - }, - }], - }); - - CompletionType::ChatCompletionChunk(ChatCompletionChunk { - id: String::new(), - created: current_time, - model: model_id.to_string(), - system_fingerprint: system_fingerprint.to_string(), - choices: vec![ChatCompletionChoice { - index: 0, - delta: tool_delta, - logprobs: None, - finish_reason: None, - }], - usage: None, - }) - } else if is_tool_arg { - // Tool call argument event - let tool_delta = ChatCompletionDelta::Tool(ToolCallDelta { - role: "assistant".to_string(), - tool_calls: vec![DeltaToolCall { - index: 0, - id: String::new(), - r#type: "function".to_string(), - function: Function { - name: None, - arguments: token_text.to_string(), - }, - }], - }); - - CompletionType::ChatCompletionChunk(ChatCompletionChunk { - id: String::new(), - created: current_time, - model: model_id.to_string(), - system_fingerprint: system_fingerprint.to_string(), - choices: vec![ChatCompletionChoice { - index: 0, - delta: tool_delta, - logprobs: None, - finish_reason: None, - }], - usage: None, - }) - } else { - // usage, finish_reason - if finish_reason.is_some() { - CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( - model_id.to_string(), - system_fingerprint.to_string(), - Some(token_text.to_string()), - None, - current_time, - None, - finish_reason, - None, - None, - )) - } else { - // Chat completion event - CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( - model_id.to_string(), - system_fingerprint.to_string(), - Some(token_text.to_string()), - None, - current_time, - None, - None, - None, - None, - )) - } - }; - - Event::default() - .json_data(chat_complete) - .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()) -} - /// Generate tokens #[utoipa::path( post, @@ -1292,22 +1182,6 @@ pub(crate) async fn chat_completions( .as_secs() }; - #[derive(serde::Serialize, serde::Deserialize, Debug)] - pub struct Function { - #[serde(rename = "_name")] - name: String, - } - - #[derive(serde::Serialize, serde::Deserialize, Debug)] - pub struct ToolDecision { - function: Function, - } - - #[derive(serde::Serialize, serde::Deserialize, Debug)] - pub struct NoToolDecision { - content: String, - } - if stream { let (headers, response_stream) = generate_stream_internal(infer, compute_type, Json(generate_request), span).await; @@ -1322,184 +1196,25 @@ pub(crate) async fn chat_completions( // Process stream tokens while let Some(Ok(stream_token)) = response_stream.next().await { let token_text = stream_token.token.text.clone(); - let mut events = Vec::new(); - let mut should_break = false; - - // Get usage information - let usage = stream_token.details.as_ref().map(|d| Usage { - completion_tokens: d.generated_tokens, - prompt_tokens: d.input_length, - total_tokens: d.input_length + d.generated_tokens, - }); - - json_buffer.push_str(&token_text); - - // Phase 1: Function name discovery - if !name_found { - // NOTE: when tools are supplied `name_found` is false until the generated buffer contains - // a partial JSON object with $.function._name value. This name determines the type - // of events to emit. If the name is "no_tool", we'll emit the "content" field as a chat - // completion event. Otherwise, we'll emit a tool call name event followed by a tool call - // argument event. In both cases we'll buffer tokens to get the name and then reset the buffer - // to collect the arguments. - if let Ok(ParseResult { - value: ToolDecision { - function: Function { name }, - }, - last_value_whole, - }) = parse_partial_json(&json_buffer) - { - if !last_value_whole { - continue; - } - name_found = true; - if name == "no_tool" { - no_tool_chosen = true; - } else { - events.push(create_event( - &token_text, - &model_id, - &system_fingerprint, - Some(name.as_str()), - false, - None, - )); - events.push(create_event( - "{", - &model_id, - &system_fingerprint, - None, - true, - None, - )); - } - - // Reset buffer for arguments - json_buffer.clear(); - json_buffer.push('{'); - } - - for event in events { - yield Ok::(event); - } - continue; - } - - // Phase 2: Content processing - let is_complete_json = json_buffer.ends_with('}') - && serde_json::from_str::(&json_buffer[..json_buffer.len() - 1]).is_ok(); - let mut edited_token = token_text; - - // Handle different flows based on context - if using_tools { - if no_tool_chosen && !is_complete_json { - // Content-only flow - if let Ok(ParseResult { - value: _, - last_value_whole, - }) = parse_partial_json::(&json_buffer) - { - let cleaned_token = if !first_quote_removed { - // trim start unil the first quote - first_quote_removed = true; - edited_token - .trim_start() - .strip_prefix('"') - .unwrap_or(&edited_token) - .to_string() - } else if last_value_whole { - should_break = true; - // trim end until the last quote - edited_token - .trim_end() - .strip_suffix('"') - .unwrap_or(&edited_token) - .to_string() - } else { - edited_token.to_string() - }; - - if !cleaned_token.is_empty() { - events.push(create_event( - &cleaned_token, - &model_id, - &system_fingerprint, - None, - false, - None, - )); - } - } - } else { - // Tool with arguments flow - if is_complete_json { - edited_token.truncate(edited_token.len() - 1); - should_break = true; - } - events.push(create_event( - &edited_token, - &model_id, - &system_fingerprint, - None, - true, - None, - )); - } - } else { - // Standard chat completion flow - if let Some(details) = stream_token.details.as_ref() { - let finish_reason = details.finish_reason.format(true); - let text = if details.finish_reason == FinishReason::Length { - &edited_token - } else { - "" - }; - events.push(create_event( - text, - &model_id, - &system_fingerprint, - None, - false, - Some(finish_reason), - )); - should_break = true; - } else { - events.push(create_event( - &edited_token, - &model_id, - &system_fingerprint, - None, - false, - None, - )); - } - } + // Process stream token into a series of events and a break signal + let (events, should_break) = process_stream_token( + token_text, + &mut json_buffer, + &mut name_found, + &mut no_tool_chosen, + &mut first_quote_removed, + using_tools, + &model_id, + &system_fingerprint, + &stream_token, + stream_options.as_ref(), + ).await; // Emit all collected events for event in events { - yield Ok::(event); + yield event; } - // Emit usage data when requested - if let (Some(usage_data), true) = ( - usage, - stream_options.as_ref().is_some_and(|o| o.include_usage) - ) { - let current_time = get_timestamp(); - - let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk { - id: String::new(), - created: current_time, - model: model_id.clone(), - system_fingerprint: system_fingerprint.clone(), - choices: vec![], - usage: Some(usage_data), - }); - - yield Ok(Event::default() - .json_data(chat_complete) - .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())); - } if should_break { break; }