
Commit ba82462

comaniac authored and LeiWang1999 committed
[Misc] Log spec decode metrics (vllm-project#6454)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 0e3ad09 commit ba82462

File tree

tests/metrics/test_metrics.py
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/test_multistep_correctness.py
vllm/engine/metrics.py

4 files changed: +137 −14 lines changed

tests/metrics/test_metrics.py

Lines changed: 49 additions & 0 deletions
@@ -168,6 +168,55 @@ def test_engine_log_metrics_regression(
     assert_metrics(engine, disable_log_stats, len(example_prompts))


+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [10])
+def test_metric_spec_decode(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    k = 5
+
+    with vllm_runner(model,
+                     dtype=dtype,
+                     disable_log_stats=False,
+                     gpu_memory_utilization=0.4,
+                     speculative_model=model,
+                     num_speculative_tokens=k,
+                     use_v2_block_manager=True) as vllm_model:
+
+        # Force log interval to be 0 to catch all metrics.
+        stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
+        stat_logger.local_interval = 0
+
+        # Note that the purpose of this test is to verify spec decode
+        # metrics instead of functional correctness, so the expected values
+        # are intended to be loose.
+        metric_name_to_expected_fn = {
+            "gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
+            "gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
+            "counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
+            "counter_spec_decode_num_draft_tokens": lambda v: v == k,
+            "counter_spec_decode_num_emitted_tokens":
+            lambda v: 0 <= v <= k + 1,
+        }
+
+        # Use one request to better inspect the metrics.
+        prompts = example_prompts[:1]
+
+        _ = vllm_model.generate_greedy(prompts, max_tokens)
+        for metric_name, is_expected in metric_name_to_expected_fn.items():
+            metric_val = getattr(
+                stat_logger.metrics,
+                metric_name).labels(**stat_logger.labels)._value.get()
+            assert is_expected(metric_val), (
+                f"the value of metric {metric_name} ({metric_val}) "
+                "does not meet expectation")
+
+
 def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
                    num_requests: int) -> None:
     if disable_log_stats:
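
The test above reads metric values through the Prometheus client objects that back each Metrics attribute. Below is a minimal, standalone sketch of that access pattern using prometheus_client directly; the metric names and labels here are illustrative, and _value.get() is an internal detail of the client library rather than public API.

    from prometheus_client import CollectorRegistry, Counter, Gauge

    # Use a private registry so the example does not touch the global one.
    registry = CollectorRegistry()
    labels = {"model_name": "facebook/opt-125m"}

    # Hypothetical metrics mirroring the naming pattern checked in the test.
    gauge_acceptance_rate = Gauge("spec_decode_draft_acceptance_rate",
                                  "Speculative token acceptance rate.",
                                  labelnames=["model_name"],
                                  registry=registry)
    counter_draft_tokens = Counter("spec_decode_num_draft_tokens",
                                   "Number of draft tokens.",
                                   labelnames=["model_name"],
                                   registry=registry)

    gauge_acceptance_rate.labels(**labels).set(0.8)
    counter_draft_tokens.labels(**labels).inc(5)

    # Reading the raw value via the internal `_value` attribute works for
    # both Gauge and Counter children, which is what the test relies on.
    assert gauge_acceptance_rate.labels(**labels)._value.get() == 0.8
    assert counter_draft_tokens.labels(**labels)._value.get() == 5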

tests/spec_decode/e2e/conftest.py

Lines changed: 36 additions & 8 deletions
@@ -162,6 +162,11 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
     }
     test_name = request.node.name

+    model = kwargs["model"]
+    draft_model = kwargs.get("speculative_model", None)
+    same_draft_target_model = (draft_model is not None
+                               and draft_model == model)
+
     def generator_inner():

         wait_for_gpu_memory_to_clear(
@@ -177,6 +182,13 @@ def generator_inner():

         print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
         llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
+
+        # Override the logging interval to 0 for the spec decode test run so
+        # that all metrics are logged in time.
+        if (baseline_or_test == "test" and not use_async
+                and llm.llm_engine.log_stats):
+            for stat_logger in llm.llm_engine.stat_loggers.values():
+                stat_logger.local_interval = 0
         set_random_seed(seed)

         yield llm
@@ -188,6 +200,9 @@ def generator_outer():
             yield llm
             del llm

+    # Set an attribute on the generator_outer function to allow us to
+    # determine whether to further check the acceptance rate in tests.
+    generator_outer.same_draft_target_model = same_draft_target_model  # type: ignore
     return generator_outer


@@ -204,18 +219,26 @@ def maybe_assert_ngram_worker(llm):

 def get_output_from_llm_generator(
         llm_generator, prompts,
-        sampling_params) -> Tuple[List[str], List[List[int]]]:
+        sampling_params) -> Tuple[List[str], List[List[int]], float]:
     tokens: List[str] = []
     token_ids: List[List[int]] = []
+    acceptance_rate: float = -1.0
     for llm in llm_generator():
         maybe_assert_ngram_worker(llm)

         outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
         token_ids = [output.outputs[0].token_ids for output in outputs]
         tokens = [output.outputs[0].text for output in outputs]
+
+        # Fetch acceptance rate if logging is enabled.
+        if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
+            stat_logger = stat_loggers["prometheus"]
+            acceptance_rate = (stat_logger.metrics.
+                               gauge_spec_decode_draft_acceptance_rate.labels(
+                                   **stat_logger.labels)._value.get())
         del llm

-    return tokens, token_ids
+    return tokens, token_ids, acceptance_rate


 def get_logprobs_from_llm_generator(
@@ -237,7 +260,8 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
                                          batch_size,
                                          max_output_len,
                                          force_output_len: bool,
-                                         print_tokens: bool = False):
+                                         print_tokens: bool = False,
+                                         ensure_all_accepted: bool = False):
     """Helper method that compares the outputs of both the baseline LLM and
     the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
     the same when temperature is zero.
@@ -267,12 +291,13 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
         temperature=temperature,
     )

-    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
-        test_llm_generator, prompts, sampling_params)
+    (spec_batch_tokens, spec_batch_token_ids,
+     acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
+                                                      prompts, sampling_params)

-    (baseline_batch_tokens,
-     baseline_batch_token_ids) = get_output_from_llm_generator(
-         baseline_llm_generator, prompts, sampling_params)
+    (baseline_batch_tokens, baseline_batch_token_ids,
+     _) = get_output_from_llm_generator(baseline_llm_generator, prompts,
+                                        sampling_params)

     assert len(baseline_batch_token_ids) == len(prompts)
     assert len(spec_batch_token_ids) == len(prompts)
@@ -287,3 +312,6 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
             print(f'{i=} {baseline_token_ids=}')
             print(f'{i=} {spec_token_ids=}')
         assert baseline_token_ids == spec_token_ids
+
+    if ensure_all_accepted:
+        assert acceptance_rate == 1.0
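
Two patterns carry the new acceptance-rate plumbing through the fixtures: the factory derives same_draft_target_model once and attaches it to the returned generator function, and the output helper only queries the gauge when stat loggers exist. Below is a rough, self-contained sketch of the function-attribute part in isolation; the names are hypothetical, not vLLM code.

    from typing import Optional


    def make_llm_generator(model: str, speculative_model: Optional[str] = None):
        # Mirrors the conftest logic: the flag is derived once from the kwargs.
        same_draft_target_model = (speculative_model is not None
                                   and speculative_model == model)

        def generator_outer():
            # The real fixture builds and yields an LLM instance here.
            yield f"llm({model}, draft={speculative_model})"

        # A plain function attribute lets tests read the flag straight off the
        # fixture without re-deriving it from the engine arguments.
        generator_outer.same_draft_target_model = same_draft_target_model
        return generator_outer


    gen = make_llm_generator("facebook/opt-125m", "facebook/opt-125m")
    assert gen.same_draft_target_model is True

    gen = make_llm_generator("facebook/opt-125m")
    assert gen.same_draft_target_model is False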

tests/spec_decode/e2e/test_multistep_correctness.py

Lines changed: 12 additions & 6 deletions
@@ -97,7 +97,7 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
         temperature=temperature,
     )

-    batch_tokens, batch_token_ids = get_output_from_llm_generator(
+    batch_tokens, batch_token_ids, _ = get_output_from_llm_generator(
         test_llm_generator, prompts, sampling_params)

     # Expect a generation for each prompt in the batch.
@@ -200,12 +200,18 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(

     Since this test is cheaper than other e2e correctness tests, we generate
     with a higher output_len.
+
+    When the draft model is the same as the target model, we further check
+    whether all speculative tokens are accepted.
     """
-    run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len=output_len,
-                                         force_output_len=True)
+    ensure_all_accepted = test_llm_generator.same_draft_target_model
+    run_greedy_equality_correctness_test(
+        baseline_llm_generator,
+        test_llm_generator,
+        batch_size,
+        max_output_len=output_len,
+        force_output_len=True,
+        ensure_all_accepted=ensure_all_accepted)


 @pytest.mark.parametrize(
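
For context on why ensure_all_accepted is meaningful here: with greedy sampling and a draft model identical to the target, every proposed token matches what the target would emit, so the reported acceptance rate should be exactly 1.0. The sketch below shows how the two reported ratios could plausibly be derived from the raw counts; the formulas are assumptions based on the metric names, not taken from this diff.

    def draft_acceptance_rate(num_accepted: int, num_draft: int) -> float:
        # Fraction of proposed (draft) tokens that the target model accepted.
        return num_accepted / num_draft if num_draft else 0.0


    def system_efficiency(num_emitted: int, num_steps: int, k: int) -> float:
        # Assumed definition: emitted tokens relative to the best case of k
        # accepted speculative tokens plus one bonus token per decoding step.
        max_emitted = num_steps * (k + 1)
        return num_emitted / max_emitted if max_emitted else 0.0


    # With an identical draft and target model, a single step with k = 5
    # drafts should accept all 5 and emit 6 tokens, so both ratios hit 1.0.
    assert draft_acceptance_rate(num_accepted=5, num_draft=5) == 1.0
    assert system_efficiency(num_emitted=6, num_steps=1, k=5) == 1.0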

vllm/engine/metrics.py

Lines changed: 40 additions & 0 deletions
@@ -133,6 +133,30 @@ def __init__(self, labelnames: List[str], max_model_len: int):
             documentation="Count of successfully processed requests.",
             labelnames=labelnames + [Metrics.labelname_finish_reason])

+        # Speculative decoding stats
+        self.gauge_spec_decode_draft_acceptance_rate = self._base_library.Gauge(
+            name="vllm:spec_decode_draft_acceptance_rate",
+            documentation="Speculative token acceptance rate.",
+            labelnames=labelnames)
+        self.gauge_spec_decode_efficiency = self._base_library.Gauge(
+            name="vllm:spec_decode_efficiency",
+            documentation="Speculative decoding system efficiency.",
+            labelnames=labelnames)
+        self.counter_spec_decode_num_accepted_tokens = (
+            self._base_library.Counter(
+                name="vllm:spec_decode_num_accepted_tokens_total",
+                documentation="Number of accepted tokens.",
+                labelnames=labelnames))
+        self.counter_spec_decode_num_draft_tokens = self._base_library.Counter(
+            name="vllm:spec_decode_num_draft_tokens_total",
+            documentation="Number of draft tokens.",
+            labelnames=labelnames)
+        self.counter_spec_decode_num_emitted_tokens = (
+            self._base_library.Counter(
+                name="vllm:spec_decode_num_emitted_tokens_total",
+                documentation="Number of emitted tokens.",
+                labelnames=labelnames))
+
         # Deprecated in favor of vllm:prompt_tokens_total
         self.gauge_avg_prompt_throughput = self._base_library.Gauge(
             name="vllm:avg_prompt_throughput_toks_per_s",
@@ -454,6 +478,22 @@ def log(self, stats: Stats):
             self.num_generation_tokens = []
             self.last_local_log = stats.now

+            if stats.spec_decode_metrics is not None:
+                self._log_gauge(
+                    self.metrics.gauge_spec_decode_draft_acceptance_rate,
+                    stats.spec_decode_metrics.draft_acceptance_rate)
+                self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
+                                stats.spec_decode_metrics.system_efficiency)
+                self._log_counter(
+                    self.metrics.counter_spec_decode_num_accepted_tokens,
+                    stats.spec_decode_metrics.accepted_tokens)
+                self._log_counter(
+                    self.metrics.counter_spec_decode_num_draft_tokens,
+                    stats.spec_decode_metrics.draft_tokens)
+                self._log_counter(
+                    self.metrics.counter_spec_decode_num_emitted_tokens,
+                    stats.spec_decode_metrics.emitted_tokens)
+

 class RayPrometheusStatLogger(PrometheusStatLogger):
     """RayPrometheusStatLogger uses Ray metrics instead."""

0 commit comments
