Expose the real-time internal state of the batcher through SSE #3065

Draft · wants to merge 5 commits into base: main

Changes from all commits
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

16 changes: 8 additions & 8 deletions backends/v3/Cargo.toml
@@ -18,7 +18,7 @@ async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
text-generation-router = { path = "../../router" }
text-generation-router = { path = "../../router", features = ["engine-state"] }
clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
@@ -37,21 +37,21 @@ slotmap = "1.0.7"
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = [
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
] }
tokio-stream = "0.1.14"
tokio-stream = { version = "0.1.14", features = ["sync"] }
tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
init-tracing-opentelemetry = { version = "0.14.1", features = [
"opentelemetry-otlp",
"opentelemetry-otlp",
] }
minijinja = { workspace = true }
minijinja-contrib = { workspace = true }
50 changes: 48 additions & 2 deletions backends/v3/src/backend.rs
@@ -6,20 +6,32 @@ use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::infer::{
Backend, EngineState, GeneratedText, InferError, InferStreamResponse,
};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::broadcast::{channel, Receiver as BroadcastReceiver, Sender as BroadcastSender};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::sync::{mpsc, Notify, RwLock};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};


pub struct BackendV3 {
/// Internal batching state exposing info for the proxy
state: Arc<RwLock<EngineState>>,

/// Request queue
queue: Queue,

/// Events streaming channel
state_events: (BroadcastSender<EngineState>, BroadcastReceiver<EngineState>),

/// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>,

/// Client clone, used for health checks to skip the queue
client: ShardedClient,
}
@@ -41,6 +53,12 @@ impl BackendV3 {

let block_size = shard_info.block_size;

let state_events = channel(1);
let state = Arc::new(RwLock::new(EngineState::new(
max_batch_total_tokens,
2 * max_batch_total_tokens,
)));

let queue = Queue::new(
shard_info.requires_padding,
block_size,
@@ -49,6 +67,8 @@ impl BackendV3 {
shard_info.speculate,
max_batch_total_tokens,
shard_info.support_chunking,
Arc::clone(&state),
state_events.0.clone(),
);
let batching_task_notifier = Arc::new(Notify::new());

@@ -62,10 +82,14 @@ impl BackendV3 {
max_batch_size,
shard_info.support_chunking,
queue.clone(),
state.clone(),
state_events.0.clone(),
batching_task_notifier.clone(),
));

Self {
state,
state_events,
queue,
batching_task_notifier,
client,
@@ -112,6 +136,10 @@ impl Backend for BackendV3 {
.is_ok()
}

fn events(&self) -> BroadcastReceiver<EngineState> {
self.state_events.0.subscribe()
Collaborator: You should send the current state as soon as a new connection is opened (new subscribe).
}
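One way to address this review note, sketched under the assumption that the `state` and `state_events` fields shown above are available (this is not part of the diff): eagerly publish the current snapshot right after subscribing. Because the broadcast send wakes every existing subscriber too, a `tokio::sync::watch` channel, which always hands the latest value to new receivers, could be a cleaner fit.

fn events(&self) -> BroadcastReceiver<EngineState> {
    // Subscribe first so the new receiver also gets the snapshot sent below.
    let receiver = self.state_events.0.subscribe();
    // Hypothetical: push the current snapshot so a new SSE client sees the
    // state immediately instead of waiting for the next batching event.
    // `try_read` is used because `events()` is not async; `EngineState` is
    // `Copy`, so the snapshot is cheap.
    if let Ok(snapshot) = self.state.try_read() {
        let _ = self.state_events.0.send(*snapshot);
    }
    receiver
}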

fn start_health(&self) -> bool {
true
}
@@ -135,6 +163,8 @@ pub(crate) async fn batching_task(
max_batch_size: Option<usize>,
support_chunking: bool,
queue: Queue,
engine_state: Arc<RwLock<EngineState>>,
batch_events: BroadcastSender<EngineState>,
notifier: Arc<Notify>,
) {
// Infinite loop
@@ -170,6 +200,22 @@
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);

// Dispatch new state to the proxy
{
// Critical section, doing as little as possible here
{
let mut engine_state = engine_state.write().await;
engine_state.in_flight = batch_max_tokens;
}

// Send new state to the channel for broadcasting
if let Err(err) = batch_events.send(*engine_state.read().await) {
tracing::warn!(
"Failed to send BatchEvent::BatchChanged({batch_max_tokens}): {err}"
)
}
}

let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

let (min_size, max_size, prefill_token_budget) = if support_chunking {
64 changes: 60 additions & 4 deletions backends/v3/src/queue.rs
@@ -6,14 +6,17 @@ use crate::client::{
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::max;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use std::sync::Arc;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::infer::{EngineState, InferError};
use text_generation_router::validation::{
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
ValidStoppingParameters,
};
use tokio::sync::{mpsc, oneshot};
use tokio::sync::broadcast::Sender as BroadcastSender;
use tokio::sync::{mpsc, oneshot, RwLock};
use tokio::time::Instant;
use tracing::log::warn;
use tracing::{info_span, instrument, Instrument, Span};

/// Queue entry
@@ -51,6 +54,8 @@ impl Queue {
speculate: u32,
max_batch_total_tokens: u32,
support_chunking: bool,
engine_state: Arc<RwLock<EngineState>>,
queue_events: BroadcastSender<EngineState>,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -64,7 +69,9 @@
speculate,
max_batch_total_tokens,
support_chunking,
engine_state,
queue_receiver,
queue_events,
));

Self { queue_sender }
@@ -113,7 +120,7 @@ impl Queue {
}
}

// Background task responsible of the queue state
// Background task responsible for the queue state
#[allow(clippy::too_many_arguments)]
async fn queue_task(
requires_padding: bool,
@@ -123,7 +130,9 @@ async fn queue_task(
speculate: u32,
max_batch_total_tokens: u32,
support_chunking: bool,
engine_state: Arc<RwLock<EngineState>>,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
queue_events: BroadcastSender<EngineState>,
) {
let mut state = State::new(
requires_padding,
@@ -138,8 +147,29 @@
while let Some(cmd) = receiver.recv().await {
match cmd {
QueueCommand::Append(entry, span) => {
let entry_num_tokens = entry.request.input_length;
span.in_scope(|| state.append(*entry));
metrics::gauge!("tgi_queue_size").increment(1.0);
metrics::gauge!("tgi_queue_size_tokens").increment(entry_num_tokens);

// Dispatch new state to the proxy
{
// Lock free operation (read)
let num_queued_tokens = engine_state.read().await.in_queue;
{
// Critical section, doing as little as possible here
let mut engine_state = engine_state.write().await;
engine_state.in_queue = num_queued_tokens + entry_num_tokens;
}

// Send new state to the channel for broadcasting
if let Err(err) = queue_events.send(*engine_state.read().await) {
tracing::warn!(
"Failed to send BatchEvent::QueueChanged({}): {err}",
num_queued_tokens + entry_num_tokens
)
}
}
Collaborator comment on lines +155 to +172:
Suggested change:
engine_state.modify(|state| state.in_queue += entry_num_tokens);

The `modify` function would then send the new state over the SSE channel.
}
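For reference, a minimal sketch of the helper the reviewer hints at (hypothetical; neither the type nor the method exists in this diff). It bundles the shared lock and the broadcast sender so that every mutation also publishes a snapshot, replacing the separate read/write/read steps above:

// Assumes the same imports as this file: Arc, RwLock, BroadcastSender, EngineState.
pub struct SharedEngineState {
    state: Arc<RwLock<EngineState>>,
    events: BroadcastSender<EngineState>,
}

impl SharedEngineState {
    /// Apply `f` under the write lock, then broadcast the resulting snapshot.
    pub async fn modify(&self, f: impl FnOnce(&mut EngineState)) {
        let snapshot = {
            let mut guard = self.state.write().await;
            f(&mut guard);
            *guard
        };
        if let Err(err) = self.events.send(snapshot) {
            tracing::warn!("Failed to broadcast engine state update: {err}");
        }
    }
}

A call site would then read `engine_state.modify(|state| state.in_queue += entry_num_tokens).await;`.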
QueueCommand::NextBatch {
min_size,
@@ -154,7 +184,33 @@
.instrument(span)
.await;
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);

{
let num_batch_tokens = state
.entries
.iter()
.map(|(_, e)| e.request.input_length)
.sum::<u32>();
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
metrics::gauge!("tgi_queue_size_tokens").set(num_batch_tokens as f64);

// Dispatch new state to the proxy
{
// Critical section, doing as little as possible here
{
let mut engine_state = engine_state.write().await;
engine_state.in_queue = num_batch_tokens;
}

// Send new state to the channel for broadcasting
if let Err(err) = queue_events.send(*engine_state.read().await) {
tracing::warn!(
"Failed to send BatchEvent::QueueChanged({}): {err}",
num_batch_tokens
)
}
}
}
}
}
}
19 changes: 10 additions & 9 deletions router/Cargo.toml
@@ -31,11 +31,11 @@ serde_json = "1.0.107"
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = [
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
] }
tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] }
@@ -46,7 +46,7 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = [
"opentelemetry-otlp",
"opentelemetry-otlp",
] }
minijinja = { workspace = true, features = ["loop_controls"] }
minijinja-contrib = { workspace = true }
@@ -57,9 +57,9 @@ image = "0.25.1"
base64 = { workspace = true }
sysinfo = "0.30.13"
uuid = { version = "1.9.1", default-features = false, features = [
"v4",
"fast-rng",
"macro-diagnostics",
"v4",
"fast-rng",
"macro-diagnostics",
] }
csv = "1.3.0"
ureq = "=2.9"
@@ -73,5 +73,6 @@ vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
[features]
default = ["ngrok"]
ngrok = ["dep:ngrok"]
engine-state = []
google = []
kserve = []
42 changes: 42 additions & 0 deletions router/src/infer/mod.rs
@@ -19,12 +19,43 @@ use serde::Serialize;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::broadcast::Receiver;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::instrument;

/// Store real-time information about batch engine usage (expressed in tokens)
#[cfg(feature = "engine-state")]
#[derive(Debug, Copy, Clone, Serialize)]
pub struct EngineState {
/// Number of tokens currently participating in the current batch
pub in_flight: u32,

/// Maximum number of tokens which can participate in a batch
pub in_flight_max: u32,

/// Number of tokens currently waiting in the queue for future batching
pub in_queue: u32,

/// Maximum number of tokens which can wait in the queue for future batching
pub in_queue_max: u32,
}

#[cfg(feature = "engine-state")]
impl EngineState {
#[inline]
pub fn new(in_flight_max: u32, in_queue_max: u32) -> Self {
EngineState {
in_flight: 0,
in_flight_max,
in_queue: 0,
in_queue_max,
}
}
}
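Since `EngineState` derives `Serialize` and `serde_json` is already a router dependency, here is a consumer-side sketch (hypothetical, not part of this diff) of what a single SSE frame carrying a snapshot would look like:

fn main() {
    let state = EngineState::new(32_768, 65_536);
    // Prints: data: {"in_flight":0,"in_flight_max":32768,"in_queue":0,"in_queue_max":65536}
    println!("data: {}", serde_json::to_string(&state).unwrap());
}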

#[async_trait]
pub trait Backend {
fn schedule(
@@ -34,6 +65,11 @@

async fn health(&self, current_health: bool) -> bool;

/// Gets a receiving-side channel that emits events about the current internal
/// batching engine state
#[cfg(feature = "engine-state")]
fn events(&self) -> Receiver<EngineState>;

/// The state of the health on startup
/// Typically false, or true if the backend includes
/// a warmup phase.
@@ -95,6 +131,12 @@ impl Infer {
}
}

#[cfg(feature = "engine-state")]
#[inline]
pub(crate) fn events(&self) -> Receiver<EngineState> {
self.backend.events()
}

/// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip_all)]
pub(crate) async fn generate_stream<'a>(