diff --git a/Cargo.lock b/Cargo.lock
index 4603f77d206..1b73ff57ebb 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5101,6 +5101,7 @@ dependencies = [
  "futures-core",
  "pin-project-lite",
  "tokio",
+ "tokio-util",
 ]

 [[package]]
diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml
index 996290ed3ba..da8e83538df 100644
--- a/backends/v3/Cargo.toml
+++ b/backends/v3/Cargo.toml
@@ -18,7 +18,7 @@ async-trait = "0.1.74"
 async-stream = "0.3.5"
 axum = { version = "0.7", features = ["json"] }
 axum-tracing-opentelemetry = "0.16"
-text-generation-router = { path = "../../router" }
+text-generation-router = { path = "../../router", features = ["engine-state"] }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 grpc-metadata = { path = "../grpc-metadata" }
 futures = "0.3.28"
@@ -37,13 +37,13 @@ slotmap = "1.0.7"
 thiserror = "1.0.48"
 tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = [
-  "rt",
-  "rt-multi-thread",
-  "parking_lot",
-  "signal",
-  "sync",
+    "rt",
+    "rt-multi-thread",
+    "parking_lot",
+    "signal",
+    "sync",
 ] }
-tokio-stream = "0.1.14"
+tokio-stream = { version = "0.1.14", features = ["sync"] }
 tower-http = { version = "0.5.1", features = ["cors"] }
 tracing = "0.1.37"
 tracing-opentelemetry = "0.21.0"
@@ -51,7 +51,7 @@ tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
 utoipa = { version = "4.2.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
-  "opentelemetry-otlp",
+    "opentelemetry-otlp",
 ] }
 minijinja = { workspace = true }
 minijinja-contrib = { workspace = true }
diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index 98e8d76f09f..bfa87773330 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -6,20 +6,32 @@ use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
-use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
+use text_generation_router::infer::{
+    Backend, EngineState, GeneratedText, InferError, InferStreamResponse,
+};
 use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, PrefillToken, Token};
+use tokio::sync::broadcast::{channel, Receiver as BroadcastReceiver, Sender as BroadcastSender};
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify};
+use tokio::sync::{mpsc, Notify, RwLock};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{info_span, instrument, Instrument, Span};

+
 pub struct BackendV3 {
+    /// Internal batching state exposing info for the proxy
+    state: Arc<RwLock<EngineState>>,
+
     /// Request queue
     queue: Queue,
+
+    /// Events streaming channel
+    state_events: (BroadcastSender<EngineState>, BroadcastReceiver<EngineState>),
+
     /// Notify batcher on queue appends
     batching_task_notifier: Arc<Notify>,
+
     /// Client clone, used for health checks to skip the queue
     client: ShardedClient,
 }
@@ -41,6 +53,12 @@ impl BackendV3 {

         let block_size = shard_info.block_size;

+        let state_events = channel(1);
+        let state = Arc::new(RwLock::new(EngineState::new(
+            max_batch_total_tokens,
+            2 * max_batch_total_tokens,
+        )));
+
         let queue = Queue::new(
             shard_info.requires_padding,
             block_size,
@@ -49,6 +67,8 @@ impl BackendV3 {
             shard_info.speculate,
             max_batch_total_tokens,
             shard_info.support_chunking,
+            Arc::clone(&state),
+            state_events.0.clone(),
         );
         let batching_task_notifier = Arc::new(Notify::new());

@@ -62,10 +82,14 @@ impl BackendV3 {
             max_batch_size,
             shard_info.support_chunking,
             queue.clone(),
+            state.clone(),
+            state_events.0.clone(),
             batching_task_notifier.clone(),
         ));
         Self {
+            state,
+            state_events,
             queue,
             batching_task_notifier,
             client,
         }
@@ -112,6 +136,10 @@ impl Backend for BackendV3 {
         .is_ok()
     }

+    fn events(&self) -> BroadcastReceiver<EngineState> {
+        self.state_events.0.subscribe()
+    }
+
     fn start_health(&self) -> bool {
         true
     }
@@ -135,6 +163,8 @@ pub(crate) async fn batching_task(
     max_batch_size: Option<usize>,
     support_chunking: bool,
     queue: Queue,
+    engine_state: Arc<RwLock<EngineState>>,
+    batch_events: BroadcastSender<EngineState>,
     notifier: Arc<Notify>,
 ) {
     // Infinite loop
@@ -170,6 +200,22 @@ pub(crate) async fn batching_task(
                 metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
                 metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);

+                // Dispatch new state to the proxy
+                {
+                    // Critical section, doing as little as possible here
+                    {
+                        let mut engine_state = engine_state.write().await;
+                        engine_state.in_flight = batch_max_tokens;
+                    }
+
+                    // Send new state to the channel for broadcasting
+                    if let Err(err) = batch_events.send(*engine_state.read().await) {
+                        tracing::warn!(
+                            "Failed to send BatchEvent::BatchChanged({batch_max_tokens}): {err}"
+                        )
+                    }
+                }
+
                 let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

                 let (min_size, max_size, prefill_token_budget) = if support_chunking {
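The batching task above keeps the shared `EngineState` behind an `Arc<RwLock<…>>` and, after every update, broadcasts a copy of it over a `tokio::sync::broadcast` channel. A minimal, self-contained sketch of that publish/subscribe pattern — type and variable names are illustrative, not taken from the patch:

```rust
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};

// Stand-in for the `EngineState` snapshot: small, `Copy`, cheap to broadcast.
#[derive(Debug, Clone, Copy)]
struct Snapshot {
    in_flight: u32,
    in_queue: u32,
}

#[tokio::main]
async fn main() {
    let state = Arc::new(RwLock::new(Snapshot { in_flight: 0, in_queue: 0 }));
    // The patch uses `channel(1)` (only the latest snapshot matters); a larger
    // buffer here just keeps the example from dropping messages.
    let (tx, mut rx) = broadcast::channel(8);

    // Subscriber: prints every snapshot until all senders are gone.
    let printer = tokio::spawn(async move {
        while let Ok(snapshot) = rx.recv().await {
            println!("engine state changed: {snapshot:?}");
        }
    });

    // Publisher: mimics the batching task — short write section, then send a copy.
    for tokens in [128u32, 256, 0] {
        {
            let mut s = state.write().await;
            s.in_flight = tokens;
        }
        // `send` only fails when there are no subscribers; the patch logs a warning and moves on.
        let _ = tx.send(*state.read().await);
    }

    drop(tx); // close the channel so the subscriber loop terminates
    printer.await.unwrap();
}
```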
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index 249eebf7615..71158cf0a66 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -6,14 +6,17 @@ use crate::client::{
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
 use std::collections::VecDeque;
-use text_generation_router::infer::InferError;
+use std::sync::Arc;
 use text_generation_router::infer::InferStreamResponse;
+use text_generation_router::infer::{EngineState, InferError};
 use text_generation_router::validation::{
     Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
     ValidStoppingParameters,
 };
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::broadcast::Sender as BroadcastSender;
+use tokio::sync::{mpsc, oneshot, RwLock};
 use tokio::time::Instant;
+use tracing::log::warn;
 use tracing::{info_span, instrument, Instrument, Span};

 /// Queue entry
@@ -51,6 +54,8 @@ impl Queue {
         speculate: u32,
         max_batch_total_tokens: u32,
         support_chunking: bool,
+        engine_state: Arc<RwLock<EngineState>>,
+        queue_events: BroadcastSender<EngineState>,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -64,7 +69,9 @@ impl Queue {
             speculate,
             max_batch_total_tokens,
             support_chunking,
+            engine_state,
             queue_receiver,
+            queue_events,
         ));

         Self { queue_sender }
@@ -113,7 +120,7 @@ impl Queue {
     }
 }

-// Background task responsible of the queue state
+// Background task responsible for the queue state
 #[allow(clippy::too_many_arguments)]
 async fn queue_task(
     requires_padding: bool,
@@ -123,7 +130,9 @@ async fn queue_task(
     speculate: u32,
     max_batch_total_tokens: u32,
     support_chunking: bool,
+    engine_state: Arc<RwLock<EngineState>>,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
+    queue_events: BroadcastSender<EngineState>,
 ) {
     let mut state = State::new(
         requires_padding,
@@ -138,8 +147,29 @@ async fn queue_task(
     while let Some(cmd) = receiver.recv().await {
         match cmd {
             QueueCommand::Append(entry, span) => {
+                let entry_num_tokens = entry.request.input_length;
                 span.in_scope(|| state.append(*entry));
                 metrics::gauge!("tgi_queue_size").increment(1.0);
+                metrics::gauge!("tgi_queue_size_tokens").increment(entry_num_tokens);
+
+                // Dispatch new state to the proxy
+                {
+                    // Lock-free operation (read)
+                    let num_queued_tokens = engine_state.read().await.in_queue;
+                    {
+                        // Critical section, doing as little as possible here
+                        let mut engine_state = engine_state.write().await;
+                        engine_state.in_queue = num_queued_tokens + entry_num_tokens;
+                    }
+
+                    // Send new state to the channel for broadcasting
+                    if let Err(err) = queue_events.send(*engine_state.read().await) {
+                        tracing::warn!(
+                            "Failed to send BatchEvent::QueueChanged({}): {err}",
+                            num_queued_tokens + entry_num_tokens
+                        )
+                    }
+                }
             }
             QueueCommand::NextBatch {
                 min_size,
                 max_size,
@@ -154,7 +184,33 @@ async fn queue_task(
                     .instrument(span)
                     .await;
                 response_sender.send(next_batch).unwrap();
-                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+
+                {
+                    let num_batch_tokens = state
+                        .entries
+                        .iter()
+                        .map(|(_, e)| e.request.input_length)
+                        .sum::<u32>();
+                    metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
+                    metrics::gauge!("tgi_queue_size_tokens").set(num_batch_tokens as f64);
+
+                    // Dispatch new state to the proxy
+                    {
+                        // Critical section, doing as little as possible here
+                        {
+                            let mut engine_state = engine_state.write().await;
+                            engine_state.in_queue = num_batch_tokens;
+                        }
+
+                        // Send new state to the channel for broadcasting
+                        if let Err(err) = queue_events.send(*engine_state.read().await) {
+                            tracing::warn!(
+                                "Failed to send BatchEvent::QueueChanged({}): {err}",
+                                num_batch_tokens
+                            )
+                        }
+                    }
+                }
             }
         }
     }
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 9326258daa2..bdaefcc747b 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -31,11 +31,11 @@ serde_json = "1.0.107"
 thiserror = "1.0.48"
 tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = [
-  "rt",
-  "rt-multi-thread",
-  "parking_lot",
-  "signal",
-  "sync",
+    "rt",
+    "rt-multi-thread",
+    "parking_lot",
+    "signal",
+    "sync",
 ] }
 tokio-stream = "0.1.14"
 tower-http = { version = "0.5.1", features = ["cors"] }
@@ -46,7 +46,7 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
-  "opentelemetry-otlp",
+    "opentelemetry-otlp",
 ] }
 minijinja = { workspace = true, features = ["loop_controls"] }
 minijinja-contrib = { workspace = true }
@@ -57,9 +57,9 @@ image = "0.25.1"
 base64 = { workspace = true }
 sysinfo = "0.30.13"
 uuid = { version = "1.9.1", default-features = false, features = [
-  "v4",
-  "fast-rng",
-  "macro-diagnostics",
+    "v4",
+    "fast-rng",
+    "macro-diagnostics",
 ] }
 csv = "1.3.0"
 ureq = "=2.9"
@@ -73,5 +73,6 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
 [features]
 default = ["ngrok"]
 ngrok = ["dep:ngrok"]
+engine-state = []
 google = []
 kserve = []
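The queue-side accounting above takes two different paths: on `Append` the queued-token figure is bumped by the entry's `input_length`, and after a `NextBatch` it is recomputed from whatever is still waiting. A small sketch of those two paths with simplified, assumed types (no locks or channels):

```rust
use std::collections::VecDeque;

// Simplified stand-in for a queue entry; only the field used for accounting.
struct Entry {
    input_length: u32,
}

// Recompute the number of tokens still waiting, as the NextBatch arm does.
fn tokens_waiting(entries: &VecDeque<(u64, Entry)>) -> u32 {
    entries.iter().map(|(_, e)| e.input_length).sum::<u32>()
}

fn main() {
    let mut entries: VecDeque<(u64, Entry)> = VecDeque::new();
    let mut in_queue = 0u32;

    // Append path: increment by the new entry's token count.
    for (id, len) in [(0u64, 12u32), (1, 30), (2, 7)] {
        in_queue += len;
        entries.push_back((id, Entry { input_length: len }));
    }
    assert_eq!(in_queue, 49);

    // NextBatch path: two entries were batched, so recompute from the remainder.
    entries.pop_front();
    entries.pop_front();
    in_queue = tokens_waiting(&entries);
    assert_eq!(in_queue, 7);
    println!("tokens still queued: {in_queue}");
}
```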
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 7eb8a41be57..0123f54fb7b 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -19,12 +19,43 @@ use serde::Serialize;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use thiserror::Error;
+use tokio::sync::broadcast::Receiver;
 use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
 use tracing::instrument;

+/// Store real-time information about batch engine usage (expressed in tokens)
+#[cfg(feature = "engine-state")]
+#[derive(Debug, Copy, Clone, Serialize)]
+pub struct EngineState {
+    /// Number of tokens currently participating in the current batch
+    pub in_flight: u32,
+
+    /// Maximum number of tokens which can participate in a batch
+    pub in_flight_max: u32,
+
+    /// Number of tokens currently waiting in the queue for future batching
+    pub in_queue: u32,
+
+    /// Maximum number of tokens which can wait in the queue for future batching
+    pub in_queue_max: u32,
+}
+
+#[cfg(feature = "engine-state")]
+impl EngineState {
+    #[inline]
+    pub fn new(in_flight_max: u32, in_queue_max: u32) -> Self {
+        EngineState {
+            in_flight: 0,
+            in_flight_max,
+            in_queue: 0,
+            in_queue_max,
+        }
+    }
+}
+
 #[async_trait]
 pub trait Backend {
     fn schedule(
@@ -34,6 +65,11 @@ pub trait Backend {

     async fn health(&self, current_health: bool) -> bool;

+    /// Gets a reference to the receiving-side channel generating events about the current internal
+    /// batching engine state
+    #[cfg(feature = "engine-state")]
+    fn events(&self) -> Receiver<EngineState>;
+
     /// The state of the health on startup
     /// Typically false, or true if the backend includes
     /// a warmup phase.
@@ -95,6 +131,12 @@ impl Infer {
         }
     }

+    #[cfg(feature = "engine-state")]
+    #[inline]
+    pub(crate) fn events(&self) -> Receiver<EngineState> {
+        self.backend.events()
+    }
+
     /// Add a new request to the queue and return a stream of InferStreamResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate_stream<'a>(
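Since `EngineState` is `Copy` and derives `Serialize`, every broadcast snapshot can be serialized as-is; the sketch below re-declares the struct only to stay self-contained and shows the JSON shape that ends up in the SSE `data:` frames of the `/state` endpoint added in the next file:

```rust
use serde::Serialize;

// Mirror of the struct above, repeated here so the example compiles on its own.
#[derive(Debug, Copy, Clone, Serialize)]
pub struct EngineState {
    pub in_flight: u32,
    pub in_flight_max: u32,
    pub in_queue: u32,
    pub in_queue_max: u32,
}

fn main() {
    let state = EngineState {
        in_flight: 512,
        in_flight_max: 4096,
        in_queue: 128,
        in_queue_max: 8192,
    };
    // Prints: {"in_flight":512,"in_flight_max":4096,"in_queue":128,"in_queue_max":8192}
    println!("{}", serde_json::to_string(&state).unwrap());
}
```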
diff --git a/router/src/server.rs b/router/src/server.rs
index e9aa4612bab..c394f7b8fe3 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -62,11 +62,13 @@ use tokio::select;
 use tokio::signal;
 use tokio::sync::oneshot;
 use tokio::time::Instant;
+use tokio_stream::wrappers::BroadcastStream;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;

+
 fn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
     let offsets = encoding.get_offsets();
     let input_ids = encoding.get_ids();
@@ -1501,6 +1503,28 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
 }

+#[utoipa::path(get, tag = "Text Generation Inference", path = "/state")]
+#[instrument(skip_all)]
+async fn state(
+    Extension(infer): Extension<Infer>,
+) -> Result<Sse<impl Stream<Item = Result<Event, axum::Error>>>, StatusCode> {
+    if cfg!(feature = "engine-state") {
+        let stream = infer.events();
+        let sse =
+            Sse::new(BroadcastStream::from(stream).map(|state| {
+                Event::default().json_data(state.map_err(|err| axum::Error::new(err))?)
+            }))
+            .keep_alive(
+                KeepAlive::new()
+                    .interval(Duration::from_secs(5))
+                    .text("more_open_models_on_hf"),
+            );
+        Ok(sse)
+    } else {
+        Err(StatusCode::NOT_IMPLEMENTED)
+    }
+}
+
 #[derive(Clone, Debug)]
 pub(crate) struct ComputeType(String);

@@ -1520,6 +1544,7 @@ metrics,
 openai_get_model_info,
 sagemaker_compatibility,
 get_chat_tokenize,
+state,
 ),
 components(
 schemas(
@@ -2171,6 +2196,11 @@ async fn start(
         "Current batch size"
     );
     metrics::describe_gauge!("tgi_queue_size", metrics::Unit::Count, "Current queue size");
+    metrics::describe_gauge!(
+        "tgi_queue_size_tokens",
+        metrics::Unit::Count,
+        "Current queue size in number of tokens"
+    );
     metrics::describe_gauge!(
         "tgi_batch_current_max_tokens",
         metrics::Unit::Count,
@@ -2367,6 +2397,7 @@ async fn start(
         .route("/health", get(health))
         .route("/ping", get(health))
         .route("/metrics", get(metrics))
+        .route("/state", get(state))
        .route("/v1/models", get(openai_get_model_info));

    let compute_type =
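A quick way to watch the new endpoint is `curl -N http://localhost:3000/state` (assuming the router listens on its default port). The sketch below does the same from Rust with `reqwest`; the URL and crate choices are assumptions, and the line-based parsing is deliberately naive — it ignores SSE frames split across chunks:

```rust
// Assumes `reqwest` with the "stream" feature, plus `futures-util` and `tokio`.
use futures_util::StreamExt;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let response = reqwest::get("http://localhost:3000/state").await?;
    let mut body = response.bytes_stream();

    while let Some(chunk) = body.next().await {
        for line in String::from_utf8_lossy(&chunk?).lines() {
            // Keep-alive comments and blank lines are skipped; only data frames are printed.
            if let Some(payload) = line.strip_prefix("data:") {
                println!("engine state: {}", payload.trim());
            }
        }
    }
    Ok(())
}
```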