
Commit 209d131

[WIP][V1][Metrics] Speculative decoding metrics
Fixes vllm-project#13990, part of vllm-project#10582

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 7ffcccf commit 209d131

File tree

vllm/v1/core/sched/scheduler.py
vllm/v1/engine/async_llm.py
vllm/v1/metrics/loggers.py
vllm/v1/metrics/stats.py
vllm/v1/spec_decode/metrics.py

5 files changed: +156 -4 lines changed

vllm/v1/core/sched/scheduler.py

Lines changed: 16 additions & 2 deletions

@@ -22,6 +22,7 @@
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager

 logger = init_logger(__name__)
@@ -535,6 +536,7 @@ def update_from_output(
         spec_token_ids = model_runner_output.spec_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
+        spec_decoding_stats = SpecDecodingStats()
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens

         new_running: list[Request] = []
@@ -559,6 +561,7 @@
                 # Otherwise, we ignore the sampler output for the request.
                 request.num_computed_tokens += num_tokens_scheduled
                 assert request.num_computed_tokens <= request.num_tokens
+                spec_decoding_stats.num_emitted_tokens += num_tokens_scheduled
             else:
                 # num_computed_tokens_step represents the number of tokens
                 # processed in the current step, considering scheduled
@@ -576,6 +579,13 @@
                     len(generated_token_ids))
                 request.num_computed_tokens += num_computed_tokens_step

+                spec_decoding_stats.num_draft_tokens += len(
+                    scheduled_spec_token_ids)
+                spec_decoding_stats.num_accepted_tokens += len(
+                    generated_token_ids) - 1
+                spec_decoding_stats.num_emitted_tokens += \
+                    num_computed_tokens_step
+
             cached_encoder_input_ids = (
                 self.encoder_cache_manager.get_cached_input_ids(request))
             # OPTIMIZATION: Avoid list(set) if the set is empty.
@@ -647,7 +657,7 @@ def update_from_output(
         self.running = new_running
         return EngineCoreOutputs(
             outputs=outputs,
-            scheduler_stats=self.make_stats(),
+            scheduler_stats=self.make_stats(spec_decoding_stats),
         )

     def add_request(self, request: Request) -> None:
@@ -708,12 +718,16 @@ def get_num_unscheduled_requests(self) -> int:
     def reset_prefix_cache(self) -> bool:
         return self.kv_cache_manager.reset_prefix_cache()

-    def make_stats(self) -> Optional[SchedulerStats]:
+    def make_stats(
+        self,
+        spec_decoding_stats: Optional[SpecDecodingStats] = None,
+    ) -> Optional[SchedulerStats]:
         if not self.log_stats:
             return None
         return SchedulerStats(
             num_running_reqs=len(self.running),
             num_waiting_reqs=len(self.waiting),
             gpu_cache_usage=self.kv_cache_manager.usage,
             prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
+            spec_decoding_stats=spec_decoding_stats,
         )
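
For readers following the accounting above, a minimal standalone sketch (not the scheduler itself) of how the three counters move for one request in one step; the token IDs, and the simplification that num_computed_tokens_step equals the number of generated tokens, are illustrative assumptions only:

    # Standalone sketch: one request was scheduled with 3 draft tokens, the
    # target model accepted 2 of them and then sampled its own token.
    from dataclasses import dataclass

    @dataclass
    class SpecDecodingStats:
        num_draft_tokens: int = 0
        num_accepted_tokens: int = 0
        num_emitted_tokens: int = 0

    stats = SpecDecodingStats()

    scheduled_spec_token_ids = [101, 102, 103]  # drafts scheduled for verification
    generated_token_ids = [101, 102, 207]       # 2 accepted drafts + 1 sampled token
    num_computed_tokens_step = len(generated_token_ids)  # simplification for the sketch

    stats.num_draft_tokens += len(scheduled_spec_token_ids)    # 3
    stats.num_accepted_tokens += len(generated_token_ids) - 1  # 2
    stats.num_emitted_tokens += num_computed_tokens_step       # 3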

vllm/v1/engine/async_llm.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def __init__(
         self.stat_loggers: list[StatLoggerBase] = []
         if self.log_stats:
             if logger.isEnabledFor(logging.INFO):
-                self.stat_loggers.append(LoggingStatLogger())
+                self.stat_loggers.append(LoggingStatLogger(vllm_config))
             self.stat_loggers.append(PrometheusStatLogger(vllm_config))

         # Tokenizer (+ ensure liveness if running in another process).

vllm/v1/metrics/loggers.py

Lines changed: 39 additions & 1 deletion

@@ -12,6 +12,7 @@
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.spec_decode.metrics import SpecDecodingMetrics

 logger = init_logger(__name__)

@@ -31,12 +32,14 @@ def log(self): # noqa

 class LoggingStatLogger(StatLoggerBase):

-    def __init__(self):
+    def __init__(self, vllm_config: VllmConfig):
         self._reset(time.monotonic())
         self.last_scheduler_stats = SchedulerStats()
         # Prefix cache metrics. This cannot be reset.
         # TODO: Make the interval configurable.
         self.prefix_caching_metrics = PrefixCachingMetrics()
+        self.spec_decoding_metrics = SpecDecodingMetrics(
+            vllm_config.speculative_config)

     def _reset(self, now):
         self.last_log_time = now
@@ -64,6 +67,10 @@ def record(self, scheduler_stats: SchedulerStats,

         self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)

+        if scheduler_stats.spec_decoding_stats is not None:
+            self.spec_decoding_metrics.observe(
+                scheduler_stats.spec_decoding_stats)
+
         self.last_scheduler_stats = scheduler_stats

     def log(self):
@@ -91,6 +98,9 @@ def log(self):
             self.prefix_caching_metrics.hit_rate * 100,
         )

+        if scheduler_stats.spec_decoding_stats is not None:
+            self.spec_decoding_metrics.log()
+

 class PrometheusStatLogger(StatLoggerBase):

@@ -296,6 +306,26 @@ def __init__(self, vllm_config: VllmConfig):
                 self.labelname_running_lora_adapters,
             ])

+        #
+        # Speculative Decoding metrics
+        # FIXME: add note on acceptance rate and system efficiency
+        #
+        self.counter_spec_decode_num_draft_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_draft_tokens_total",
+                documentation="Number of draft tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_accepted_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_accepted_tokens_total",
+                documentation="Number of accepted tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+        self.counter_spec_decode_num_emitted_tokens = \
+            prometheus_client.Counter(
+                name="vllm:spec_decode_num_emitted_tokens_total",
+                documentation="Number of emitted tokens.",
+                labelnames=labelnames).labels(*labelvalues)
+
         #
         # Cache config info metric
         #
@@ -332,6 +362,14 @@ def record(self, scheduler_stats: SchedulerStats,
         self.counter_gpu_prefix_cache_hits.inc(
             scheduler_stats.prefix_cache_stats.hits)

+        if scheduler_stats.spec_decoding_stats is not None:
+            self.counter_spec_decode_num_draft_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_draft_tokens)
+            self.counter_spec_decode_num_accepted_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_accepted_tokens)
+            self.counter_spec_decode_num_emitted_tokens.inc(
+                scheduler_stats.spec_decoding_stats.num_emitted_tokens)
+
         if iteration_stats is None:
             return
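
The FIXME above asks for a note on how acceptance rate and system efficiency relate to these three counters. As a hedged sketch only (the ratio names follow the log message in vllm/v1/spec_decode/metrics.py below; the helper itself is hypothetical and not part of this commit), both can be derived from counter deltas taken over the same window:

    # Hypothetical helper: derive the two ratios from deltas of the three
    # counters between two scrapes. max_emitted is the best-case emitted
    # count for the same window (cf. get_max_num_emitted_tokens below).
    def spec_decode_ratios(draft: int, accepted: int, emitted: int,
                           max_emitted: int) -> tuple[float, float]:
        # Draft acceptance rate: fraction of proposed draft tokens that the
        # target model accepted.
        acceptance_rate = accepted / draft if draft else float("nan")
        # System efficiency: emitted tokens relative to the best case where
        # every draft and its bonus token had been emitted.
        efficiency = emitted / max_emitted if max_emitted else float("nan")
        return acceptance_rate, efficiency

    # e.g. 100 drafts, 60 accepted, 80 emitted, 120 best-case emitted
    print(spec_decode_ratios(100, 60, 80, 120))  # (0.6, 0.666...)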

vllm/v1/metrics/stats.py

Lines changed: 4 additions & 0 deletions

@@ -4,6 +4,8 @@
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Optional

+from vllm.v1.spec_decode.metrics import SpecDecodingStats
+
 if TYPE_CHECKING:
     from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
     from vllm.v1.engine.output_processor import RequestState
@@ -35,6 +37,8 @@ class SchedulerStats:
     prefix_cache_stats: PrefixCacheStats = field(
         default_factory=PrefixCacheStats)

+    spec_decoding_stats: Optional[SpecDecodingStats] = None
+

 @dataclass
 class LoRAStats:

vllm/v1/spec_decode/metrics.py

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+
+import numpy as np
+
+from vllm.config import SpeculativeConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class SpecDecodingStats:
+    num_draft_tokens: int = 0
+    num_accepted_tokens: int = 0
+    num_emitted_tokens: int = 0
+
+    def take(self):
+        copied = SpecDecodingStats(self.num_draft_tokens,
+                                   self.num_accepted_tokens,
+                                   self.num_emitted_tokens)
+        self.reset()
+        return copied
+
+    def reset(self):
+        self.num_draft_tokens = 0
+        self.num_accepted_tokens = 0
+        self.num_emitted_tokens = 0
+
+
+class SpecDecodingMetrics:
+
+    def __init__(self, speculative_config: SpeculativeConfig):
+        self.num_spec_tokens = (speculative_config.num_speculative_tokens
+                                if speculative_config is not None else 0)
+        self.reset()
+
+    def reset(self):
+        self.num_draft_tokens: list[int] = []
+        self.num_accepted_tokens: list[int] = []
+        self.num_emitted_tokens: list[int] = []
+
+    def observe(self, spec_decoding_stats: SpecDecodingStats):
+        self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
+        self.num_accepted_tokens.append(
+            spec_decoding_stats.num_accepted_tokens)
+        self.num_emitted_tokens.append(spec_decoding_stats.num_emitted_tokens)
+
+    def log(self):
+        num_draft_tokens = np.sum(self.num_draft_tokens)
+        num_accepted_tokens = np.sum(self.num_accepted_tokens)
+        num_emitted_tokens = np.sum(self.num_emitted_tokens)
+        # FIXME: relies on num_draft_tokens % k == 0 assumption
+        #max_num_emitted_tokens = get_max_num_emitted_tokens(
+        #    draft_tokens=num_draft_tokens, k=self.num_spec_tokens)
+        draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens
+                                 if num_draft_tokens > 0 else float("nan"))
+        #system_efficiency = (num_emitted_tokens / max_num_emitted_tokens
+        #                     if max_num_emitted_tokens > 0 else float("nan"))
+        system_efficiency = float("nan")
+        logger.info(
+            "Speculative metrics: "
+            "Draft acceptance rate: %.3f, "
+            "System efficiency: %.3f, "
+            "Number of speculative tokens: %d, "
+            "Number of accepted tokens: %d, "
+            "Number of draft tokens: %d, "
+            "Number of emitted tokens: %d.", draft_acceptance_rate,
+            system_efficiency, self.num_spec_tokens, num_accepted_tokens,
+            num_draft_tokens, num_emitted_tokens)
+        self.reset()
+
+
+def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
+    """Calculate the number of emitted tokens, assuming all tokens accepted.
+
+    This is equal to the number of sequences that have been speculated on,
+    times (speculation len + 1). The +1 comes from the bonus token.
+    """
+    # Determine the number of sequences that have been speculated on. Since
+    # the batch size can be variable, we divide by k.
+    print(f"DRAFT TOKENS {draft_tokens} K {k}")
+    # Cannot assume this - ngram proposer says "If there are less than k
+    # tokens follow the match, we will return the maximum amount of tokens
+    # until the end."
+    assert draft_tokens % k == 0
+    total_num_spec_seqs = draft_tokens // k
+
+    # A single sequence may emit k accepted tokens and one bonus token in
+    # the best case.
+    num_emitted_per_seq_if_all_accepted = k + 1
+
+    # The max num of emitted tokens is the number of speculated sequences
+    # times the max emitted per seq.
+    return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
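
As a worked example of the helper above (the counts are hypothetical, and the num_draft_tokens % k == 0 assumption that the in-code comments flag is taken to hold here):

    # With k = 5 speculative tokens and 15 draft tokens observed, three
    # sequences were speculated on, so the best case is 3 * (5 + 1) = 18
    # emitted tokens (5 accepted drafts + 1 bonus token per sequence).
    k = 5
    draft_tokens = 15
    max_emitted = (draft_tokens // k) * (k + 1)                  # 18

    num_accepted_tokens = 9                                      # hypothetical
    num_emitted_tokens = 12                                      # hypothetical

    draft_acceptance_rate = num_accepted_tokens / draft_tokens   # 0.6
    system_efficiency = num_emitted_tokens / max_emitted         # 0.666...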
