Commit bc32bc7

[V1][Metrics] Implement vllm:lora_requests_info metric (#13504)
1 parent ab1091d commit bc32bc7

File tree: 3 files changed, +121 -10 lines changed
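This commit surfaces per-adapter queue state through a single Prometheus gauge whose labels carry the adapter names. One way to see the result is to scrape a running vLLM OpenAI-compatible server. The sketch below is illustrative only: the localhost URL and the sample line in the trailing comment are assumptions, not part of this commit.

# Illustrative only: scrape the new metric from a running vLLM server.
# The URL and the example sample in the comment are assumptions.
from urllib.request import urlopen

with urlopen("http://localhost:8000/metrics") as resp:  # hypothetical server
    for line in resp.read().decode().splitlines():
        if line.startswith("vllm:lora_requests_info"):
            print(line)

# A returned sample might look like this (the value is the timestamp of the
# last update, label values are hypothetical):
# vllm:lora_requests_info{max_lora="2",running_lora_adapters="sql-lora",waiting_lora_adapters=""} 1.74e+09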

vllm/v1/engine/output_processor.py

Lines changed: 19 additions & 4 deletions
@@ -11,7 +11,8 @@
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
 from vllm.v1.engine.logprobs import LogprobsProcessor
-from vllm.v1.metrics.stats import IterationStats, RequestStateStats
+from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
+                                   RequestStateStats)


 @dataclass
@@ -26,6 +27,7 @@ class RequestState:
     def __init__(
         self,
         request_id: str,
+        lora_name: Optional[str],
         output_kind: RequestOutputKind,
         prompt: Optional[str],
         prompt_token_ids: List[int],
@@ -36,6 +38,7 @@ def __init__(
         log_stats: bool,
     ):
         self.request_id = request_id
+        self.lora_name = lora_name
         self.output_kind = output_kind
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
@@ -58,6 +61,8 @@ def from_new_request(
     ) -> "RequestState":
         return cls(
             request_id=request.request_id,
+            lora_name=(request.lora_request.name
+                       if request.lora_request is not None else None),
             output_kind=request.sampling_params.output_kind,
             prompt=request.prompt,
             prompt_token_ids=request.prompt_token_ids,
@@ -86,6 +91,7 @@ def __init__(
         self.log_stats = log_stats
         self.tokenizer = tokenizer
         self.request_states: Dict[str, RequestState] = {}
+        self.lora_states = LoRARequestStates()

     def is_request_active(self, request_id: str) -> bool:
         return request_id in self.request_states
@@ -101,7 +107,9 @@ def abort_requests(
         request_ids: List[str],
     ) -> None:
         for request_id in request_ids:
-            self.request_states.pop(request_id, None)
+            req_state = self.request_states.pop(request_id, None)
+            if req_state is not None:
+                self.lora_states.abort_request(req_state)

     def add_request(
         self,
@@ -112,11 +120,13 @@ def add_request(
         if request_id in self.request_states:
             raise ValueError(f"Request id {request_id} already running.")

-        self.request_states[request_id] = RequestState.from_new_request(
+        req_state = RequestState.from_new_request(
             tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
             request=request,
             queue=queue,
             log_stats=self.log_stats)
+        self.request_states[request_id] = req_state
+        self.lora_states.add_request(req_state)

     def process_outputs(
         self,
@@ -214,6 +224,8 @@ def process_outputs(
                     finish_reason,
                     iteration_stats)

+        self.lora_states.update_iteration_stats(iteration_stats)
+
         return OutputProcessorOutput(
             request_outputs=request_outputs,
             reqs_to_abort=reqs_to_abort,
@@ -226,13 +238,15 @@ def _update_stats_from_output(self, req_state: RequestState,
         if iteration_stats is None:
             return

+        lora_stats = self.lora_states.get_stats(req_state)
+
         assert engine_core_timestamp is not None
         assert req_state.stats is not None
         iteration_stats.update_from_output(engine_core_output,
                                            engine_core_timestamp,
                                            req_state.is_prefilling,
                                            req_state.prompt_len,
-                                           req_state.stats)
+                                           req_state.stats, lora_stats)

     def _update_stats_from_finished(self, req_state: RequestState,
                                     request_output: RequestOutput,
@@ -246,6 +260,7 @@ def _update_stats_from_finished(self, req_state: RequestState,
         iteration_stats.update_from_finished_request(finish_reason,
                                                      request_output,
                                                      req_state.stats)
+        self.lora_states.finish_request(req_state)

     @staticmethod
     def _make_request_output(
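The only new piece of per-request state here is lora_name, copied from the incoming request so stats can later be grouped by adapter. As a hedged illustration (the adapter name and path below are made up), vLLM's LoRARequest exposes the .name property that RequestState.from_new_request() reads:

# Hypothetical values; LoRARequest.name is the property read by
# RequestState.from_new_request() in the diff above.
from vllm.lora.request import LoRARequest

lora_request = LoRARequest(lora_name="sql-lora", lora_int_id=1,
                           lora_path="/path/to/adapter")
print(lora_request.name)  # -> "sql-lora"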

vllm/v1/metrics/loggers.py

Lines changed: 30 additions & 1 deletion
@@ -2,7 +2,7 @@

 import time
 from abc import ABC, abstractmethod
-from typing import Dict, List
+from typing import Dict, List, Optional

 import numpy as np
 import prometheus_client
@@ -233,6 +233,22 @@ def __init__(self, vllm_config: VllmConfig):
             buckets=request_latency_buckets,
             labelnames=labelnames).labels(*labelvalues)

+        self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
+        if vllm_config.lora_config is not None:
+            self.labelname_max_lora = "max_lora"
+            self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
+            self.labelname_running_lora_adapters = "running_lora_adapters"
+            self.max_lora = vllm_config.lora_config.max_loras
+            self.gauge_lora_info = \
+                prometheus_client.Gauge(
+                    name="vllm:lora_requests_info",
+                    documentation="Running stats on lora requests.",
+                    labelnames=[
+                        self.labelname_max_lora,
+                        self.labelname_waiting_lora_adapters,
+                        self.labelname_running_lora_adapters,
+                    ])
+
         self.log_metrics_info("cache_config", vllm_config.cache_config)

     def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
@@ -295,6 +311,19 @@ def log(self, scheduler_stats: SchedulerStats,
         for prefill_time in iteration_stats.prefill_times_iter:
             self.histogram_prefill_time_request.observe(prefill_time)

+        if self.gauge_lora_info is not None:
+            running_lora_adapters = \
+                ",".join(iteration_stats.running_lora_adapters.keys())
+            waiting_lora_adapters = \
+                ",".join(iteration_stats.waiting_lora_adapters.keys())
+            lora_info_labels = {
+                self.labelname_running_lora_adapters: running_lora_adapters,
+                self.labelname_waiting_lora_adapters: waiting_lora_adapters,
+                self.labelname_max_lora: self.max_lora,
+            }
+            self.gauge_lora_info.labels(**lora_info_labels)\
+                .set_to_current_time()
+
     @staticmethod
     def _unregister_vllm_metrics():
         # Unregister any existing vLLM collectors (for CI/CD
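The logger follows the info-metric convention: the waiting and running adapter names are packed into label values as comma-separated strings, and the gauge value itself is just the timestamp of the last update, so dashboards read the labels rather than the value. Below is a minimal standalone sketch of that pattern using prometheus_client; the metric name, adapter names, and max_lora value are invented for illustration and are not part of the commit.

# Standalone sketch of the info-style gauge pattern; all names and values
# here are made up for illustration.
import prometheus_client

gauge_lora_info = prometheus_client.Gauge(
    name="demo:lora_requests_info",
    documentation="Running stats on lora requests (demo).",
    labelnames=["max_lora", "waiting_lora_adapters", "running_lora_adapters"])

# One label set per unique combination of adapter-name strings; the sample
# value records when the combination was last observed.
gauge_lora_info.labels(
    max_lora="2",
    waiting_lora_adapters="",
    running_lora_adapters="sql-lora,chat-lora").set_to_current_time()

print(prometheus_client.generate_latest().decode())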

vllm/v1/metrics/stats.py

Lines changed: 72 additions & 5 deletions
@@ -2,11 +2,12 @@

 import time
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Dict, List, Optional, Set

 if TYPE_CHECKING:
     from vllm.outputs import RequestOutput
     from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
+    from vllm.v1.output_processor import RequestState


 @dataclass
@@ -36,6 +37,12 @@ class SchedulerStats:
         default_factory=PrefixCacheStats)


+@dataclass
+class LoRAStats:
+    waiting_requests: Set[str] = field(default_factory=set)
+    running_requests: Set[str] = field(default_factory=set)
+
+
 @dataclass
 class RequestStateStats:
     """Stats that need to be tracked across delta updates."""
@@ -76,14 +83,17 @@ def __init__(self):
         self.time_per_output_tokens_iter: List[float] = []
         self.queue_times_iter: List[float] = []
         self.prefill_times_iter: List[float] = []
+        self.waiting_lora_adapters: Dict[str, int] = {}
+        self.running_lora_adapters: Dict[str, int] = {}

     def _time_since(self, start: float) -> float:
         """Calculate an interval relative to this iteration's timestamp."""
         return self.iteration_timestamp - start

     def update_from_output(self, output: "EngineCoreOutput",
                            engine_core_timestamp: float, is_prefilling: bool,
-                           prompt_len: int, req_stats: RequestStateStats):
+                           prompt_len: int, req_stats: RequestStateStats,
+                           lora_stats: Optional[LoRAStats]):
         num_new_generation_tokens = len(output.new_token_ids)

         self.num_generation_tokens += num_new_generation_tokens
@@ -105,7 +115,8 @@ def update_from_output(self, output: "EngineCoreOutput",

         # Process request-level engine core events
         if output.events is not None:
-            self.update_from_events(output.events, is_prefilling, req_stats)
+            self.update_from_events(output.request_id, output.events,
+                                    is_prefilling, req_stats, lora_stats)

         # Process the batch-level "new tokens" engine core event
         if is_prefilling:
@@ -123,17 +134,21 @@ def update_from_output(self, output: "EngineCoreOutput",
         if num_new_generation_tokens > 0:
             req_stats.last_token_ts = engine_core_timestamp

-    def update_from_events(self, events: List["EngineCoreEvent"],
-                           is_prefilling: bool, req_stats: RequestStateStats):
+    def update_from_events(self, req_id: str, events: List["EngineCoreEvent"],
+                           is_prefilling: bool, req_stats: RequestStateStats,
+                           lora_stats: Optional[LoRAStats]):
         # Avoid circular dependency
         from vllm.v1.engine import EngineCoreEventType
         for event in events:
             if event.type == EngineCoreEventType.QUEUED:
                 req_stats.queued_ts = event.timestamp
+                if lora_stats is not None:
+                    lora_stats.waiting_requests.add(req_id)
             elif event.type == EngineCoreEventType.SCHEDULED:
                 queued_interval = event.timestamp - req_stats.queued_ts
                 self.queue_times_iter.append(queued_interval)
                 req_stats.scheduled_ts = event.timestamp
+                LoRARequestStates.scheduled_request(lora_stats, req_id)

     def update_from_finished_request(self, finish_reason: "FinishReason",
                                      request_output: "RequestOutput",
@@ -151,3 +166,55 @@ def update_from_finished_request(self, finish_reason: "FinishReason",
             inference_time=inference_time,
             decode_time=decode_time)
         self.finished_requests.append(finished_req)
+
+
+class LoRARequestStates:
+    """Per-LoRA request state stats."""
+
+    def __init__(self):
+        self.lora_name_to_stats: Dict[str, LoRAStats] = {}
+
+    def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]:
+        if req_state.lora_name is None:
+            return None
+        if req_state.lora_name not in self.lora_name_to_stats:
+            self.lora_name_to_stats[req_state.lora_name] = LoRAStats()
+        return self.lora_name_to_stats[req_state.lora_name]
+
+    def add_request(self, req_state: 'RequestState'):
+        if (lora_stats := self.get_stats(req_state)) is not None:
+            lora_stats.waiting_requests.add(req_state.request_id)
+
+    def finish_request(self, req_state: 'RequestState'):
+        if req_state.lora_name is None:
+            return
+        lora_stats = self.lora_name_to_stats[req_state.lora_name]
+        lora_stats.running_requests.remove(req_state.request_id)
+
+    def abort_request(self, req_state: 'RequestState'):
+        if req_state.lora_name is None:
+            return
+        lora_stats = self.lora_name_to_stats[req_state.lora_name]
+        lora_stats.waiting_requests.discard(req_state.request_id)
+        lora_stats.running_requests.discard(req_state.request_id)
+
+    # Break the pattern for this lifecycle method so we can
+    # call it from IterationStats.update_from_events()
+    @staticmethod
+    def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str):
+        if lora_stats is None:
+            return
+        lora_stats.waiting_requests.remove(request_id)
+        lora_stats.running_requests.add(request_id)
+
+    def update_iteration_stats(self,
+                               iteration_stats: Optional[IterationStats]):
+        if iteration_stats is None:
+            return
+        for lora_name, stats in self.lora_name_to_stats.items():
+            if stats.waiting_requests:
+                iteration_stats.waiting_lora_adapters[lora_name] = \
+                    len(stats.waiting_requests)
+            if stats.running_requests:
+                iteration_stats.running_lora_adapters[lora_name] = \
+                    len(stats.running_requests)
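LoRARequestStates tracks each adapter's request IDs through a waiting → running → finished lifecycle and converts the sets into per-adapter counts once per logging iteration. A rough usage sketch follows, assuming this commit is installed; SimpleNamespace stands in for RequestState since only the request_id and lora_name attributes are read, and the adapter name is made up.

# Rough lifecycle sketch; "sql-lora" and the request id are hypothetical.
from types import SimpleNamespace

from vllm.v1.metrics.stats import IterationStats, LoRARequestStates

lora_states = LoRARequestStates()
req = SimpleNamespace(request_id="req-0", lora_name="sql-lora")

lora_states.add_request(req)             # QUEUED: added to waiting_requests
lora_stats = lora_states.get_stats(req)
LoRARequestStates.scheduled_request(lora_stats, req.request_id)  # -> running

iteration_stats = IterationStats()
lora_states.update_iteration_stats(iteration_stats)
print(iteration_stats.running_lora_adapters)  # {'sql-lora': 1}

lora_states.finish_request(req)          # removed from running_requests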
