
Commit 8b31770

[V1][Metrics] Add e2e/queue/prefill/decode/inference time histograms
Follow on from vllm-project#12579, part of vllm-project#10582. Add the following histograms:

- vllm:e2e_request_latency_seconds
- vllm:request_queue_time_seconds
- vllm:request_inference_time_seconds
- vllm:request_prefill_time_seconds
- vllm:request_decode_time_seconds

e2e_request_latency is calculated relative to the arrival_time timestamp recorded by the frontend. For the rest ... we want to capture (in histograms) precise per-request timing intervals between certain events in the engine core:

```
<< queued timestamp >>
  [ queue interval ]
<< scheduled timestamp >>
  [ prefill interval ]
<< new token timestamp (FIRST) >>
  [ inter-token interval ]
<< new token timestamp >>
  [ decode interval (relative to first token time) ]
  [ inference interval (relative to scheduled time) ]
<< new token timestamp (FINISHED) >>
```

We want to collect these metrics in the frontend process, to keep the engine core as free as possible, yet the intervals must be calculated from timestamps recorded by the engine core. The engine core therefore includes these timestamps in EngineCoreOutput (per request) as a sequence of timestamped events, and the frontend calculates the intervals and logs them.

Where we record these timestamped events:

- QUEUED: scheduler add_request()
- SCHEDULED: scheduler schedule()

There is an implicit NEW_TOKENS timestamp based on an initialization timestamp recorded on EngineCoreOutputs.

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 58047c6 commit 8b31770
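For reference, here is a minimal sketch (not code from this commit; the helper and argument names are illustrative) of how a frontend can turn the timestamps described above into the five histograms:

```python
# Illustrative sketch only: deriving the request intervals from the engine
# core's monotonic timestamps. Names here are assumptions, not vLLM API.
import time


def request_intervals(arrival_time: float, queued_ts: float,
                      scheduled_ts: float, first_token_ts: float,
                      last_token_ts: float) -> dict:
    return {
        # e2e latency uses the frontend's own wall-clock arrival_time.
        "e2e_request_latency": time.time() - arrival_time,
        # The remaining intervals use engine-core monotonic timestamps.
        "request_queue_time": scheduled_ts - queued_ts,
        "request_prefill_time": first_token_ts - scheduled_ts,
        "request_decode_time": last_token_ts - first_token_ts,
        "request_inference_time": last_token_ts - scheduled_ts,
    }
```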

File tree: 9 files changed (+241, -50 lines)

tests/entrypoints/openai/test_metrics.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -85,6 +85,10 @@ async def client(server):
     "vllm:time_per_output_token_seconds":
     [("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))],
     "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
+    "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)],
+    "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)],
+    "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)],
+    "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)],
     "vllm:request_prompt_tokens":
     [("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
      ("_count", _NUM_REQUESTS)],
@@ -169,6 +173,18 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
     "vllm:e2e_request_latency_seconds_sum",
    "vllm:e2e_request_latency_seconds_bucket",
     "vllm:e2e_request_latency_seconds_count",
+    "vllm:request_queue_time_seconds_sum",
+    "vllm:request_queue_time_seconds_bucket",
+    "vllm:request_queue_time_seconds_count",
+    "vllm:request_inference_time_seconds_sum",
+    "vllm:request_inference_time_seconds_bucket",
+    "vllm:request_inference_time_seconds_count",
+    "vllm:request_prefill_time_seconds_sum",
+    "vllm:request_prefill_time_seconds_bucket",
+    "vllm:request_prefill_time_seconds_count",
+    "vllm:request_decode_time_seconds_sum",
+    "vllm:request_decode_time_seconds_bucket",
+    "vllm:request_decode_time_seconds_count",
     "vllm:request_prompt_tokens_sum",
     "vllm:request_prompt_tokens_bucket",
     "vllm:request_prompt_tokens_count",
@@ -218,6 +234,21 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
     "vllm:time_per_output_token_seconds_sum",
     "vllm:time_per_output_token_seconds_bucket",
     "vllm:time_per_output_token_seconds_count",
+    "vllm:e2e_request_latency_seconds_sum",
+    "vllm:e2e_request_latency_seconds_bucket",
+    "vllm:e2e_request_latency_seconds_count",
+    "vllm:request_queue_time_seconds_sum",
+    "vllm:request_queue_time_seconds_bucket",
+    "vllm:request_queue_time_seconds_count",
+    "vllm:request_inference_time_seconds_sum",
+    "vllm:request_inference_time_seconds_bucket",
+    "vllm:request_inference_time_seconds_count",
+    "vllm:request_prefill_time_seconds_sum",
+    "vllm:request_prefill_time_seconds_bucket",
+    "vllm:request_prefill_time_seconds_count",
+    "vllm:request_decode_time_seconds_sum",
+    "vllm:request_decode_time_seconds_bucket",
+    "vllm:request_decode_time_seconds_count",
 ]
```
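The test above asserts that each new histogram reports a count for every finished request. As a quick manual check against a running server, something like the following sketch works (it assumes a local OpenAI-compatible server on port 8000 and is not part of this commit):

```python
# Sketch: manually verify the new histogram series on a running vLLM server.
# Assumes an OpenAI-compatible server is already listening on localhost:8000.
import requests

EXPECTED = [
    "vllm:request_queue_time_seconds",
    "vllm:request_inference_time_seconds",
    "vllm:request_prefill_time_seconds",
    "vllm:request_decode_time_seconds",
]

text = requests.get("http://localhost:8000/metrics").text
for name in EXPECTED:
    # Each Prometheus histogram exposes _sum, _bucket and _count series.
    for suffix in ("_sum", "_bucket", "_count"):
        assert name + suffix in text, f"missing {name}{suffix}"
```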

tests/v1/engine/test_output_processor.py

Lines changed: 15 additions & 8 deletions
```diff
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import math
+import time
 from typing import Dict, List, Optional
 
 import pytest
@@ -15,6 +16,7 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.metrics.stats import IterationStats
 
 
 def _ref_convert_id_to_token(
@@ -603,6 +605,7 @@ def test_iteration_stats(dummy_test_vectors):
     output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
                                        log_stats=True)
     engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
+    engine_core_timestamp = time.monotonic()
 
     # Make N requests.
     requests = [
@@ -630,8 +633,9 @@ def test_iteration_stats(dummy_test_vectors):
 
     # First iteration has 2 prefills.
     outputs = engine_core.get_outputs()[:num_active]
-    processed_outputs = output_processor.process_outputs(outputs)
-    iteration_stats = processed_outputs.iteration_stats
+    iteration_stats = IterationStats(output_processor.log_stats)
+    output_processor.process_outputs(outputs, engine_core_timestamp,
+                                     iteration_stats)
     total_prompt_tokens = sum([
         len(prompt_tokens)
         for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
@@ -642,8 +646,9 @@ def test_iteration_stats(dummy_test_vectors):
 
     # Just decodes in this step.
     outputs = engine_core.get_outputs()[:num_active]
-    processed_outputs = output_processor.process_outputs(outputs)
-    iteration_stats = processed_outputs.iteration_stats
+    iteration_stats = IterationStats(output_processor.log_stats)
+    output_processor.process_outputs(outputs, engine_core_timestamp,
+                                     iteration_stats)
 
     assert iteration_stats.num_prompt_tokens == 0
     assert iteration_stats.num_generation_tokens == num_active
@@ -652,17 +657,19 @@ def test_iteration_stats(dummy_test_vectors):
     output_processor.add_request(inactive_request)
     num_active += 1
     outputs = engine_core.get_outputs()[:num_active]
-    processed_outputs = output_processor.process_outputs(outputs)
-    iteration_stats = processed_outputs.iteration_stats
+    iteration_stats = IterationStats(output_processor.log_stats)
+    output_processor.process_outputs(outputs, engine_core_timestamp,
+                                     iteration_stats)
     total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])
 
     assert iteration_stats.num_prompt_tokens == total_prompt_tokens
     assert iteration_stats.num_generation_tokens == num_active
 
     # Just decodes in this step.
     outputs = engine_core.get_outputs()[:num_active]
-    processed_outputs = output_processor.process_outputs(outputs)
-    iteration_stats = processed_outputs.iteration_stats
+    iteration_stats = IterationStats(output_processor.log_stats)
+    output_processor.process_outputs(outputs, engine_core_timestamp,
+                                     iteration_stats)
 
     assert iteration_stats.num_prompt_tokens == 0
     assert iteration_stats.num_generation_tokens == num_active
```

vllm/v1/core/scheduler.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import time
 from collections import deque
 from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
 
@@ -107,6 +108,8 @@ def schedule(self) -> "SchedulerOutput":
         scheduled_encoder_inputs: Dict[str, List[int]] = {}
         encoder_budget = self.max_num_encoder_input_tokens
 
+        scheduled_timestamp = time.monotonic()
+
         # First, schedule the RUNNING requests.
         req_index = 0
         while req_index < len(self.running) and token_budget > 0:
@@ -246,6 +249,7 @@ def schedule(self) -> "SchedulerOutput":
                 self.running.append(request)
                 if request.status == RequestStatus.WAITING:
                     scheduled_new_reqs.append(request)
+                    request.scheduled(scheduled_timestamp)
                 elif request.status == RequestStatus.PREEMPTED:
                     scheduled_resumed_reqs.append(request)
                 else:
@@ -508,7 +512,8 @@ def update_from_output(
                     finish_reason=request.get_finished_reason(),
                     new_logprobs=new_logprobs,
                     new_prompt_logprobs_tensors=prompt_logprobs_tensors,
-                    stop_reason=request.stop_reason))
+                    stop_reason=request.stop_reason,
+                    events=request.take_events()))
 
             if not stopped:
                 new_running.append(request)
@@ -541,6 +546,7 @@ def _check_stop(self, request: Request) -> bool:
     def add_request(self, request: Request) -> None:
         self.waiting.append(request)
         self.requests[request.request_id] = request
+        request.queued()
 
     def finish_requests(
         self,
```
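The scheduler above calls request.queued(), request.scheduled() and request.take_events(), which live in vllm/v1/request.py and are not shown in this diff view. A plausible sketch of those helpers, written against the EngineCoreEvent API added in vllm/v1/engine/__init__.py below (treat the details as assumptions, not the committed code):

```python
# Sketch (assumption): the Request-side helpers called by the scheduler above.
# The real implementation lives in vllm/v1/request.py, which is not part of
# this excerpt.
from typing import List, Optional

from vllm.v1.engine import EngineCoreEvent, EngineCoreEventType


class RequestEventsMixin:
    """Records timestamped events so the scheduler can ship them out."""

    def __init__(self) -> None:
        self.events: List[EngineCoreEvent] = []

    def queued(self) -> None:
        # QUEUED is stamped when the scheduler accepts the request.
        self.events.append(
            EngineCoreEvent.new_event(EngineCoreEventType.QUEUED))

    def scheduled(self, timestamp: Optional[float] = None) -> None:
        # SCHEDULED reuses one timestamp for all requests scheduled this step.
        self.events.append(
            EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED, timestamp))

    def take_events(self) -> Optional[List[EngineCoreEvent]]:
        # Hand accumulated events to EngineCoreOutput and reset the buffer.
        if not self.events:
            return None
        events, self.events = self.events, []
        return events
```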

vllm/v1/engine/__init__.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import enum
+import time
 from typing import List, Optional, Union
 
 import msgspec
@@ -60,6 +61,30 @@ class EngineCoreRequest(
     lora_request: Optional[LoRARequest]
 
 
+class EngineCoreEventType(enum.IntEnum):
+    """The type of engine core request event."""
+    QUEUED = 1
+    SCHEDULED = 2
+
+
+class EngineCoreEvent(msgspec.Struct):
+    """A timestamped engine core event associated with a request.
+
+    The timestamp is a monotonic timestamp, used by the engine frontend
+    to calculate intervals between engine core events. These timestamps
+    should not be compared with timestamps from other processes.
+    """
+    type: EngineCoreEventType
+    timestamp: float
+
+    @classmethod
+    def new_event(cls,
+                  event_type: EngineCoreEventType,
+                  timestamp: Optional[float] = None) -> "EngineCoreEvent":
+        timestamp = time.monotonic() if timestamp is None else timestamp
+        return cls(event_type, timestamp)
+
+
 class EngineCoreOutput(
     msgspec.Struct,
     array_like=True,  # type: ignore[call-arg]
@@ -74,6 +99,7 @@ class EngineCoreOutput(
 
     finish_reason: Optional[FinishReason] = None
     stop_reason: Union[int, str, None] = None
+    events: Optional[List[EngineCoreEvent]] = None
 
     @property
     def finished(self) -> bool:
@@ -92,6 +118,11 @@ class EngineCoreOutputs(
     # [num_reqs]
     outputs: List[EngineCoreOutput]
     scheduler_stats: SchedulerStats
+    timestamp: float = 0.0
+
+    def __post_init__(self):
+        if self.timestamp == 0.0:
+            self.timestamp = time.monotonic()
 
 
 class EngineCoreRequestType(enum.Enum):
```
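A short usage sketch for the new structs (illustrative only): events are stamped with time.monotonic() in the engine core process, and the frontend compares them against EngineCoreOutputs.timestamp, so none of these values should be mixed with wall-clock times from other processes.

```python
# Illustrative only; not test code from this commit.
import time

from vllm.v1.engine import EngineCoreEvent, EngineCoreEventType

queued = EngineCoreEvent.new_event(EngineCoreEventType.QUEUED)
# ... request waits in the scheduler queue ...
scheduled = EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED)
# ... prefill runs; EngineCoreOutputs.timestamp then marks when this batch's
# outputs were produced, standing in for the "new token" timestamp ...
batch_timestamp = time.monotonic()

queue_interval = scheduled.timestamp - queued.timestamp
prefill_interval = batch_timestamp - scheduled.timestamp
```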

vllm/v1/engine/async_llm.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -246,6 +246,8 @@ async def _run_output_handler(self):
                 # 1) Pull EngineCoreOutputs from the EngineCore.
                 outputs = await self.engine_core.get_output_async()
 
+                iteration_stats = IterationStats(self.log_stats)
+
                 # Split outputs into chunks of at most
                 # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                 # event loop for too long.
@@ -257,14 +259,12 @@ async def _run_output_handler(self):
                         outputs.outputs,
                         cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
 
-                iteration_stats = None
                 for i, outputs_slice in enumerate(slices):
                     # 2) Process EngineCoreOutputs.
                     processed_outputs = self.output_processor.process_outputs(
-                        outputs_slice, iteration_stats)
+                        outputs_slice, outputs.timestamp, iteration_stats)
                     # NOTE: RequestOutputs are pushed to their queues.
                     assert not processed_outputs.request_outputs
-                    iteration_stats = processed_outputs.iteration_stats
 
                     # Allow other asyncio tasks to run between chunks
                     if i + 1 < len(slices):
```

vllm/v1/engine/output_processor.py

Lines changed: 14 additions & 13 deletions
```diff
@@ -19,7 +19,6 @@ class OutputProcessorOutput:
 
     request_outputs: List[RequestOutput]
     reqs_to_abort: List[str]
-    iteration_stats: IterationStats
 
 
 class RequestState:
@@ -45,7 +44,7 @@ def __init__(
         self.is_prefilling = True
         self.queue = queue
 
-        self.stats = RequestStateStats(last_token_time=arrival_time)
+        self.stats = RequestStateStats(arrival_time=arrival_time)
 
     @classmethod
     def from_new_request(
@@ -117,6 +116,7 @@ def add_request(
     def process_outputs(
         self,
         engine_core_outputs: List[EngineCoreOutput],
+        engine_core_timestamp: Optional[float] = None,
         iteration_stats: Optional[IterationStats] = None,
     ) -> OutputProcessorOutput:
         """
@@ -145,8 +145,6 @@ def process_outputs(
 
         request_outputs: List[RequestOutput] = []
         reqs_to_abort: List[str] = []
-        if not iteration_stats:
-            iteration_stats = IterationStats(self.log_stats)
         for engine_core_output in engine_core_outputs:
             req_id = engine_core_output.request_id
             req_state = self.request_states.get(req_id)
@@ -155,10 +153,13 @@ def process_outputs(
                 continue
 
             # 1) Compute stats for this iteration.
-            iteration_stats.update_from_output(engine_core_output,
-                                               req_state.is_prefilling,
-                                               req_state.prompt_len,
-                                               req_state.stats)
+            if iteration_stats is not None:
+                assert engine_core_timestamp is not None
+                iteration_stats.update_from_output(engine_core_output,
+                                                   engine_core_timestamp,
+                                                   req_state.is_prefilling,
+                                                   req_state.prompt_len,
+                                                   req_state.stats)
 
             new_token_ids = engine_core_output.new_token_ids
             finish_reason = engine_core_output.finish_reason
@@ -205,15 +206,15 @@ def process_outputs(
                 # detected stop string, abort needed in EngineCore.
                 reqs_to_abort.append(req_id)
 
-                # Track per-request stats.
-                assert finish_reason is not None
-                iteration_stats.update_from_finished_request(
-                    finish_reason, request_output, req_state.stats)
+                # Track per-request stats
+                if iteration_stats is not None:
+                    assert finish_reason is not None
+                    iteration_stats.update_from_finished_request(
+                        finish_reason, request_output, req_state.stats)
 
         return OutputProcessorOutput(
             request_outputs=request_outputs,
             reqs_to_abort=reqs_to_abort,
-            iteration_stats=iteration_stats,
         )
 
     @staticmethod
```
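process_outputs() now forwards the batch timestamp and the per-request events to IterationStats.update_from_output() in vllm/v1/metrics/stats.py, which is among the nine changed files but not shown in this excerpt. A rough sketch of the bookkeeping this implies (names and structure are assumptions, not the committed code):

```python
# Sketch (assumption) of the interval bookkeeping behind update_from_output();
# not the committed vllm/v1/metrics/stats.py code.
from typing import List, Optional

from vllm.v1.engine import EngineCoreEvent, EngineCoreEventType


class PerRequestTimes:
    """Timestamps accumulated for one request across engine iterations."""

    def __init__(self) -> None:
        self.queued_ts: Optional[float] = None
        self.scheduled_ts: Optional[float] = None

    def update(self, events: Optional[List[EngineCoreEvent]],
               engine_core_timestamp: float, is_prefilling: bool,
               queue_times: List[float], prefill_times: List[float]) -> None:
        for event in events or []:
            if event.type == EngineCoreEventType.QUEUED:
                self.queued_ts = event.timestamp
            elif event.type == EngineCoreEventType.SCHEDULED:
                self.scheduled_ts = event.timestamp
                if self.queued_ts is not None:
                    # Queue interval: QUEUED -> SCHEDULED.
                    queue_times.append(self.scheduled_ts - self.queued_ts)
        if is_prefilling and self.scheduled_ts is not None:
            # Prefill interval: SCHEDULED -> first token, where the batch
            # timestamp stands in for the first-token time.
            prefill_times.append(engine_core_timestamp - self.scheduled_ts)
```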

vllm/v1/metrics/loggers.py

Lines changed: 49 additions & 0 deletions
```diff
@@ -162,6 +162,45 @@ def __init__(self, model_config: ModelConfig):
             ],
             labelnames=labelnames).labels(*labelvalues)
 
+        request_latency_buckets = [
+            0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
+            40.0, 50.0, 60.0
+        ]
+        self.histogram_e2e_time_request = \
+            prometheus_client.Histogram(
+                name="vllm:e2e_request_latency_seconds",
+                documentation="Histogram of e2e request latency in seconds.",
+                buckets=request_latency_buckets,
+                labelnames=labelnames).labels(*labelvalues)
+        self.histogram_queue_time_request = \
+            prometheus_client.Histogram(
+                name="vllm:request_queue_time_seconds",
+                documentation=
+                "Histogram of time spent in WAITING phase for request.",
+                buckets=request_latency_buckets,
+                labelnames=labelnames).labels(*labelvalues)
+        self.histogram_inference_time_request = \
+            prometheus_client.Histogram(
+                name="vllm:request_inference_time_seconds",
+                documentation=
+                "Histogram of time spent in RUNNING phase for request.",
+                buckets=request_latency_buckets,
+                labelnames=labelnames).labels(*labelvalues)
+        self.histogram_prefill_time_request = \
+            prometheus_client.Histogram(
+                name="vllm:request_prefill_time_seconds",
+                documentation=
+                "Histogram of time spent in PREFILL phase for request.",
+                buckets=request_latency_buckets,
+                labelnames=labelnames).labels(*labelvalues)
+        self.histogram_decode_time_request = \
+            prometheus_client.Histogram(
+                name="vllm:request_decode_time_seconds",
+                documentation=
+                "Histogram of time spent in DECODE phase for request.",
+                buckets=request_latency_buckets,
+                labelnames=labelnames).labels(*labelvalues)
+
     def log(self, scheduler_stats: SchedulerStats,
             iteration_stats: IterationStats):
         """Log to prometheus."""
@@ -176,6 +215,12 @@ def log(self, scheduler_stats: SchedulerStats,
 
         for finished_request in iteration_stats.finished_requests:
             self.counter_request_success[finished_request.finish_reason].inc()
+            self.histogram_e2e_time_request.observe(
+                finished_request.e2e_latency)
+            self.histogram_inference_time_request.observe(
+                finished_request.inference_time)
+            self.histogram_decode_time_request.observe(
+                finished_request.decode_time)
             self.histogram_num_prompt_tokens_request.observe(
                 finished_request.num_prompt_tokens)
             self.histogram_num_generation_tokens_request.observe(
@@ -185,6 +230,10 @@ def log(self, scheduler_stats: SchedulerStats,
             self.histogram_time_to_first_token.observe(ttft)
         for tpot in iteration_stats.time_per_output_tokens_iter:
             self.histogram_time_per_output_token.observe(tpot)
+        for queue_time in iteration_stats.queue_times_iter:
+            self.histogram_queue_time_request.observe(queue_time)
+        for prefill_time in iteration_stats.prefill_times_iter:
+            self.histogram_prefill_time_request.observe(prefill_time)
 
     @staticmethod
     def _unregister_vllm_metrics():
```
