diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 1438f9d5a7b..9ae8303df54 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -11,7 +11,8 @@
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
 from vllm.v1.engine.logprobs import LogprobsProcessor
-from vllm.v1.metrics.stats import IterationStats, RequestStateStats
+from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
+                                   RequestStateStats)
 
 
 @dataclass
@@ -26,6 +27,7 @@ class RequestState:
     def __init__(
         self,
         request_id: str,
+        lora_name: Optional[str],
         output_kind: RequestOutputKind,
         prompt: Optional[str],
         prompt_token_ids: List[int],
@@ -36,6 +38,7 @@ def __init__(
         log_stats: bool,
     ):
         self.request_id = request_id
+        self.lora_name = lora_name
        self.output_kind = output_kind
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
@@ -58,6 +61,8 @@ def from_new_request(
     ) -> "RequestState":
         return cls(
             request_id=request.request_id,
+            lora_name=(request.lora_request.name
+                       if request.lora_request is not None else None),
             output_kind=request.sampling_params.output_kind,
             prompt=request.prompt,
             prompt_token_ids=request.prompt_token_ids,
@@ -86,6 +91,7 @@ def __init__(
         self.log_stats = log_stats
         self.tokenizer = tokenizer
         self.request_states: Dict[str, RequestState] = {}
+        self.lora_states = LoRARequestStates()
 
     def is_request_active(self, request_id: str) -> bool:
         return request_id in self.request_states
@@ -101,7 +107,9 @@ def abort_requests(
         request_ids: List[str],
     ) -> None:
         for request_id in request_ids:
-            self.request_states.pop(request_id, None)
+            req_state = self.request_states.pop(request_id, None)
+            if req_state is not None:
+                self.lora_states.abort_request(req_state)
 
     def add_request(
         self,
@@ -112,11 +120,13 @@ def add_request(
         if request_id in self.request_states:
             raise ValueError(f"Request id {request_id} already running.")
 
-        self.request_states[request_id] = RequestState.from_new_request(
+        req_state = RequestState.from_new_request(
             tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
             request=request,
             queue=queue,
             log_stats=self.log_stats)
+        self.request_states[request_id] = req_state
+        self.lora_states.add_request(req_state)
 
     def process_outputs(
         self,
@@ -214,6 +224,8 @@ def process_outputs(
                                              finish_reason,
                                              iteration_stats)
 
+        self.lora_states.update_iteration_stats(iteration_stats)
+
         return OutputProcessorOutput(
             request_outputs=request_outputs,
             reqs_to_abort=reqs_to_abort,
@@ -226,13 +238,15 @@ def _update_stats_from_output(self, req_state: RequestState,
         if iteration_stats is None:
             return
 
+        lora_stats = self.lora_states.get_stats(req_state)
+
         assert engine_core_timestamp is not None
         assert req_state.stats is not None
         iteration_stats.update_from_output(engine_core_output,
                                            engine_core_timestamp,
                                            req_state.is_prefilling,
                                            req_state.prompt_len,
-                                           req_state.stats)
+                                           req_state.stats, lora_stats)
 
     def _update_stats_from_finished(self, req_state: RequestState,
                                     request_output: RequestOutput,
@@ -246,6 +260,7 @@ def _update_stats_from_finished(self, req_state: RequestState,
         iteration_stats.update_from_finished_request(finish_reason,
                                                      request_output,
                                                      req_state.stats)
+        self.lora_states.finish_request(req_state)
 
     @staticmethod
     def _make_request_output(
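Note on the wiring above: `lora_name` is taken straight from the request's attached `LoRARequest` (or left as `None`), and the `OutputProcessor` then drives the per-adapter lifecycle against `LoRARequestStates` via `add_request`, `scheduled_request`, `finish_request`, and `abort_request`. A minimal sketch of the name derivation, mirroring `RequestState.from_new_request()`; the adapter name, id, and path here are made-up values, and `LoRARequest` is vLLM's existing request class:

from typing import Optional

from vllm.lora.request import LoRARequest

# Illustrative adapter; name/id/path are hypothetical.
lora_request: Optional[LoRARequest] = LoRARequest(
    lora_name="sql-lora", lora_int_id=1, lora_path="/adapters/sql-lora")

# Mirrors the lora_name=... expression in RequestState.from_new_request().
lora_name = (lora_request.name
             if lora_request is not None else None)
assert lora_name == "sql-lora"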
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index e562b4145af..2c17da0ebc8 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -2,7 +2,7 @@
 
 import time
 from abc import ABC, abstractmethod
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import numpy as np
 import prometheus_client
@@ -233,6 +233,22 @@ def __init__(self, vllm_config: VllmConfig):
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)
 
+        self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
+        if vllm_config.lora_config is not None:
+            self.labelname_max_lora = "max_lora"
+            self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
+            self.labelname_running_lora_adapters = "running_lora_adapters"
+            self.max_lora = vllm_config.lora_config.max_loras
+            self.gauge_lora_info = \
+                prometheus_client.Gauge(
+                    name="vllm:lora_requests_info",
+                    documentation="Running stats on lora requests.",
+                    labelnames=[
+                        self.labelname_max_lora,
+                        self.labelname_waiting_lora_adapters,
+                        self.labelname_running_lora_adapters,
+                    ])
+
         self.log_metrics_info("cache_config", vllm_config.cache_config)
 
     def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
@@ -295,6 +311,19 @@ def log(self, scheduler_stats: SchedulerStats,
         for prefill_time in iteration_stats.prefill_times_iter:
             self.histogram_prefill_time_request.observe(prefill_time)
 
+        if self.gauge_lora_info is not None:
+            running_lora_adapters = \
+                ",".join(iteration_stats.running_lora_adapters.keys())
+            waiting_lora_adapters = \
+                ",".join(iteration_stats.waiting_lora_adapters.keys())
+            lora_info_labels = {
+                self.labelname_running_lora_adapters: running_lora_adapters,
+                self.labelname_waiting_lora_adapters: waiting_lora_adapters,
+                self.labelname_max_lora: self.max_lora,
+            }
+            self.gauge_lora_info.labels(**lora_info_labels)\
+                .set_to_current_time()
+
     @staticmethod
     def _unregister_vllm_metrics():
         # Unregister any existing vLLM collectors (for CI/CD
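For context, `vllm:lora_requests_info` packs the comma-joined adapter names into label values and uses the sample value as a last-updated timestamp, following the V0 gauge of the same name rather than the more common one-series-per-adapter pattern. A standalone sketch of that exposition using only `prometheus_client`; the adapter name "sql-lora" and `max_lora="4"` are made-up values, and the isolated registry is just to keep the printed output small:

import prometheus_client

# Same shape as the gauge registered above.
registry = prometheus_client.CollectorRegistry()
gauge_lora_info = prometheus_client.Gauge(
    name="vllm:lora_requests_info",
    documentation="Running stats on lora requests.",
    labelnames=["max_lora", "waiting_lora_adapters",
                "running_lora_adapters"],
    registry=registry)

# One series per (waiting, running) combination; the value is the
# wall-clock time of the last update, not a request count.
gauge_lora_info.labels(
    max_lora="4",
    waiting_lora_adapters="",
    running_lora_adapters="sql-lora").set_to_current_time()

print(prometheus_client.generate_latest(registry).decode())
# Prints roughly:
#   # HELP vllm:lora_requests_info Running stats on lora requests.
#   # TYPE vllm:lora_requests_info gauge
#   vllm:lora_requests_info{max_lora="4",waiting_lora_adapters="",
#                           running_lora_adapters="sql-lora"} 1.74...e+09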
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index a0e6204929e..74d4a1bc4fb 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -2,11 +2,12 @@
 
 import time
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Dict, List, Optional, Set
 
 if TYPE_CHECKING:
     from vllm.outputs import RequestOutput
     from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
+    from vllm.v1.engine.output_processor import RequestState
 
 
 @dataclass
@@ -36,6 +37,12 @@ class SchedulerStats:
         default_factory=PrefixCacheStats)
 
 
+@dataclass
+class LoRAStats:
+    waiting_requests: Set[str] = field(default_factory=set)
+    running_requests: Set[str] = field(default_factory=set)
+
+
 @dataclass
 class RequestStateStats:
     """Stats that need to be tracked across delta updates."""
@@ -76,6 +83,8 @@ def __init__(self):
         self.time_per_output_tokens_iter: List[float] = []
         self.queue_times_iter: List[float] = []
         self.prefill_times_iter: List[float] = []
+        self.waiting_lora_adapters: Dict[str, int] = {}
+        self.running_lora_adapters: Dict[str, int] = {}
 
     def _time_since(self, start: float) -> float:
         """Calculate an interval relative to this iteration's timestamp."""
@@ -83,7 +92,8 @@
 
     def update_from_output(self, output: "EngineCoreOutput",
                            engine_core_timestamp: float, is_prefilling: bool,
-                           prompt_len: int, req_stats: RequestStateStats):
+                           prompt_len: int, req_stats: RequestStateStats,
+                           lora_stats: Optional[LoRAStats]):
         num_new_generation_tokens = len(output.new_token_ids)
 
         self.num_generation_tokens += num_new_generation_tokens
@@ -105,7 +115,8 @@ def update_from_output(self, output: "EngineCoreOutput",
 
         # Process request-level engine core events
         if output.events is not None:
-            self.update_from_events(output.events, is_prefilling, req_stats)
+            self.update_from_events(output.request_id, output.events,
+                                    is_prefilling, req_stats, lora_stats)
 
         # Process the batch-level "new tokens" engine core event
         if is_prefilling:
@@ -123,17 +134,21 @@ def update_from_output(self, output: "EngineCoreOutput",
         if num_new_generation_tokens > 0:
             req_stats.last_token_ts = engine_core_timestamp
 
-    def update_from_events(self, events: List["EngineCoreEvent"],
-                           is_prefilling: bool, req_stats: RequestStateStats):
+    def update_from_events(self, req_id: str, events: List["EngineCoreEvent"],
+                           is_prefilling: bool, req_stats: RequestStateStats,
+                           lora_stats: Optional[LoRAStats]):
         # Avoid circular dependency
         from vllm.v1.engine import EngineCoreEventType
         for event in events:
             if event.type == EngineCoreEventType.QUEUED:
                 req_stats.queued_ts = event.timestamp
+                if lora_stats is not None:
+                    lora_stats.waiting_requests.add(req_id)
             elif event.type == EngineCoreEventType.SCHEDULED:
                 queued_interval = event.timestamp - req_stats.queued_ts
                 self.queue_times_iter.append(queued_interval)
                 req_stats.scheduled_ts = event.timestamp
+                LoRARequestStates.scheduled_request(lora_stats, req_id)
 
     def update_from_finished_request(self, finish_reason: "FinishReason",
                                      request_output: "RequestOutput",
@@ -151,3 +166,55 @@ def update_from_finished_request(self, finish_reason: "FinishReason",
                                  inference_time=inference_time,
                                  decode_time=decode_time)
         self.finished_requests.append(finished_req)
+
+
+class LoRARequestStates:
+    """Per-LoRA request state stats."""
+
+    def __init__(self):
+        self.lora_name_to_stats: Dict[str, LoRAStats] = {}
+
+    def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]:
+        if req_state.lora_name is None:
+            return None
+        if req_state.lora_name not in self.lora_name_to_stats:
+            self.lora_name_to_stats[req_state.lora_name] = LoRAStats()
+        return self.lora_name_to_stats[req_state.lora_name]
+
+    def add_request(self, req_state: 'RequestState'):
+        if (lora_stats := self.get_stats(req_state)) is not None:
+            lora_stats.waiting_requests.add(req_state.request_id)
+
+    def finish_request(self, req_state: 'RequestState'):
+        if req_state.lora_name is None:
+            return
+        lora_stats = self.lora_name_to_stats[req_state.lora_name]
+        lora_stats.running_requests.remove(req_state.request_id)
+
+    def abort_request(self, req_state: 'RequestState'):
+        if req_state.lora_name is None:
+            return
+        lora_stats = self.lora_name_to_stats[req_state.lora_name]
+        lora_stats.waiting_requests.discard(req_state.request_id)
+        lora_stats.running_requests.discard(req_state.request_id)
+
+    # Break the pattern for this lifecycle method so we can
+    # call it from IterationStats.update_from_events()
+    @staticmethod
+    def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str):
+        if lora_stats is None:
+            return
+        lora_stats.waiting_requests.remove(request_id)
+        lora_stats.running_requests.add(request_id)
+
+    def update_iteration_stats(self,
+                               iteration_stats: Optional[IterationStats]):
+        if iteration_stats is None:
+            return
+        for lora_name, stats in self.lora_name_to_stats.items():
+            if stats.waiting_requests:
+                iteration_stats.waiting_lora_adapters[lora_name] = \
+                    len(stats.waiting_requests)
+            if stats.running_requests:
+                iteration_stats.running_lora_adapters[lora_name] = \
+                    len(stats.running_requests)
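A quick way to sanity-check the bookkeeping above, end to end: drive one request through the waiting -> running -> finished transitions and confirm what `update_iteration_stats()` reports. `FakeRequestState` is a made-up stand-in exposing only the two attributes the stats code reads (`request_id`, `lora_name`), and "sql-lora" is an illustrative adapter name:

from dataclasses import dataclass
from typing import Optional

from vllm.v1.metrics.stats import IterationStats, LoRARequestStates


@dataclass
class FakeRequestState:
    """Stand-in for vllm.v1.engine.output_processor.RequestState."""
    request_id: str
    lora_name: Optional[str]


lora_states = LoRARequestStates()
req = FakeRequestState(request_id="req-0", lora_name="sql-lora")

lora_states.add_request(req)  # waiting: {"req-0"}

# SCHEDULED event: waiting -> running, as update_from_events() does.
lora_stats = lora_states.get_stats(req)
LoRARequestStates.scheduled_request(lora_stats, req.request_id)

iteration_stats = IterationStats()
lora_states.update_iteration_stats(iteration_stats)
assert iteration_stats.running_lora_adapters == {"sql-lora": 1}
assert iteration_stats.waiting_lora_adapters == {}

lora_states.finish_request(req)  # running set drops "req-0"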