
[Misc] Log spec decode metrics #6454

Merged · 7 commits · Jul 16, 2024
49 changes: 49 additions & 0 deletions tests/metrics/test_metrics.py
@@ -168,6 +168,55 @@ def test_engine_log_metrics_regression(
assert_metrics(engine, disable_log_stats, len(example_prompts))


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [10])
def test_metric_spec_decode(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
k = 5

with vllm_runner(model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
use_v2_block_manager=True) as vllm_model:

# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
stat_logger.local_interval = 0

# Note that the purpose of this test is to verify spec decode
# metrics instead of functional correctness, so the expected values
# are intended to be loose.
metric_name_to_expected_fn = {
"gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
"gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
"counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
"counter_spec_decode_num_draft_tokens": lambda v: v == k,
"counter_spec_decode_num_emitted_tokens":
lambda v: 0 <= v <= k + 1,
}

# Use one request to better inspect the metrics.
prompts = example_prompts[:1]

_ = vllm_model.generate_greedy(prompts, max_tokens)
for metric_name, is_expected in metric_name_to_expected_fn.items():
metric_val = getattr(
stat_logger.metrics,
metric_name).labels(**stat_logger.labels)._value.get()
assert is_expected(metric_val), (
f"the value of metric {metric_name} ({metric_val}) "
"does not meet expectation")


def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
num_requests: int) -> None:
if disable_log_stats:
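For readers of the new test above: the bounds it asserts follow from standard speculative-decoding bookkeeping, where the acceptance rate is the fraction of proposed draft tokens that the target model accepts, and efficiency compares emitted tokens against the per-step maximum of k + 1. A minimal sketch of that arithmetic, assuming these common definitions rather than vLLM's exact implementation:

```python
from dataclasses import dataclass


@dataclass
class SpecDecodeCounts:
    """Raw token counts accumulated across spec-decode steps (illustrative)."""
    accepted_tokens: int  # draft tokens accepted by the target model
    draft_tokens: int     # draft tokens proposed by the draft model
    emitted_tokens: int   # tokens actually appended to the sequence


def derive_ratios(counts: SpecDecodeCounts, k: int, num_steps: int):
    # Acceptance rate: fraction of proposed draft tokens that were accepted.
    acceptance_rate = counts.accepted_tokens / max(counts.draft_tokens, 1)
    # System efficiency: emitted tokens versus the theoretical maximum of
    # k + 1 tokens per step (k accepted drafts plus one token from the target).
    efficiency = counts.emitted_tokens / max((k + 1) * num_steps, 1)
    return acceptance_rate, efficiency


# With k = 5 and a single step, both ratios stay in [0, 1], draft_tokens
# equals k, and emitted_tokens is at most k + 1 -- the bounds the test checks.
print(derive_ratios(SpecDecodeCounts(accepted_tokens=3, draft_tokens=5,
                                     emitted_tokens=4), k=5, num_steps=1))
```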
44 changes: 36 additions & 8 deletions tests/spec_decode/e2e/conftest.py
@@ -162,6 +162,11 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
}
test_name = request.node.name

model = kwargs["model"]
draft_model = kwargs.get("speculative_model", None)
same_draft_target_model = (draft_model is not None
and draft_model == model)

def generator_inner():

wait_for_gpu_memory_to_clear(
@@ -177,6 +182,13 @@ def generator_inner():

print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)

# Override the logging interval to 0 for spec decode test runs so
# that all metrics are logged in time.
if (baseline_or_test == "test" and not use_async
and llm.llm_engine.log_stats):
for stat_logger in llm.llm_engine.stat_loggers.values():
stat_logger.local_interval = 0
set_random_seed(seed)

yield llm
@@ -188,6 +200,9 @@ def generator_outer():
yield llm
del llm

# Set an attribute on the generator_outer function so tests can
# determine whether to additionally check the acceptance rate.
generator_outer.same_draft_target_model = same_draft_target_model # type: ignore
return generator_outer


@@ -204,18 +219,26 @@ def maybe_assert_ngram_worker(llm):

def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
sampling_params) -> Tuple[List[str], List[List[int]], float]:
tokens: List[str] = []
token_ids: List[List[int]] = []
acceptance_rate: float = -1.0
for llm in llm_generator():
maybe_assert_ngram_worker(llm)

outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]

# Fetch acceptance rate if logging is enabled.
if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
stat_logger = stat_loggers["prometheus"]
acceptance_rate = (stat_logger.metrics.
gauge_spec_decode_draft_acceptance_rate.labels(
**stat_logger.labels)._value.get())
del llm

return tokens, token_ids
return tokens, token_ids, acceptance_rate


def get_logprobs_from_llm_generator(
@@ -237,7 +260,8 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False):
print_tokens: bool = False,
ensure_all_accepted: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
the same when temperature is zero.
@@ -267,12 +291,13 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
temperature=temperature,
)

spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
(spec_batch_tokens, spec_batch_token_ids,
acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
prompts, sampling_params)

(baseline_batch_tokens,
baseline_batch_token_ids) = get_output_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
(baseline_batch_tokens, baseline_batch_token_ids,
_) = get_output_from_llm_generator(baseline_llm_generator, prompts,
sampling_params)

assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)
@@ -287,3 +312,6 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids

if ensure_all_accepted:
assert acceptance_rate == 1.0
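The acceptance-rate lookup above goes through prometheus_client's private `_value` attribute, mirroring the metrics test. If the gauge is registered with the default registry, the same number can also be read through the library's public test helper; a hedged sketch, which assumes `model_name` is the only label vLLM attaches to these metrics:

```python
from prometheus_client import REGISTRY


def read_acceptance_rate(model_name: str) -> float:
    # get_sample_value returns None until the gauge has been set at least once.
    value = REGISTRY.get_sample_value(
        "vllm:spec_decode_draft_acceptance_rate",
        labels={"model_name": model_name},  # assumed label set
    )
    return -1.0 if value is None else value
```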
18 changes: 12 additions & 6 deletions tests/spec_decode/e2e/test_multistep_correctness.py
@@ -97,7 +97,7 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
temperature=temperature,
)

batch_tokens, batch_token_ids = get_output_from_llm_generator(
batch_tokens, batch_token_ids, _ = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)

# Expect a generation for each prompt in the batch.
@@ -200,12 +200,18 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(

Since this test is cheaper than other e2e correctness tests, we generate
with a higher output_len.

When the draft model is the same as the target model, we further check
whether all speculative tokens are accepted.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
ensure_all_accepted = test_llm_generator.same_draft_target_model
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
ensure_all_accepted=ensure_all_accepted)


@pytest.mark.parametrize(
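The new `ensure_all_accepted` check relies on a property of speculative sampling: a draft token x is accepted with probability min(1, p_target(x) / p_draft(x)), so when the draft model is literally the target model the ratio is 1 and every proposal is accepted. A schematic illustration of that acceptance rule (not vLLM's rejection-sampler code):

```python
import numpy as np


def accept_draft_token(p_target: np.ndarray, p_draft: np.ndarray,
                       token: int, rng: np.random.Generator) -> bool:
    # Accept the proposed token with probability min(1, p_target / p_draft).
    ratio = p_target[token] / max(p_draft[token], 1e-12)
    return rng.random() < min(1.0, ratio)


# With an identical draft and target model the two distributions match,
# the ratio is exactly 1, and every proposal is accepted -- which is why
# the test can require acceptance_rate == 1.0 in that configuration.
rng = np.random.default_rng(0)
probs = np.array([0.7, 0.2, 0.1])
assert all(accept_draft_token(probs, probs, t, rng) for t in range(3))
```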
40 changes: 40 additions & 0 deletions vllm/engine/metrics.py
@@ -133,6 +133,30 @@ def __init__(self, labelnames: List[str], max_model_len: int):
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])

# Speculative decoding stats
self.gauge_spec_decode_draft_acceptance_rate = self._base_library.Gauge(
name="vllm:spec_decode_draft_acceptance_rate",
documentation="Speulative token acceptance rate.",
labelnames=labelnames)
self.gauge_spec_decode_efficiency = self._base_library.Gauge(
name="vllm:spec_decode_efficiency",
documentation="Speculative decoding system efficiency.",
labelnames=labelnames)
self.counter_spec_decode_num_accepted_tokens = (
self._base_library.Counter(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames))
self.counter_spec_decode_num_draft_tokens = self._base_library.Counter(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames)
self.counter_spec_decode_num_emitted_tokens = (
self._base_library.Counter(
name="vllm:spec_decode_num_emitted_tokens_total",
documentation="Number of emitted tokens.",
labelnames=labelnames))

# Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = self._base_library.Gauge(
name="vllm:avg_prompt_throughput_toks_per_s",
@@ -454,6 +478,22 @@ def log(self, stats: Stats):
self.num_generation_tokens = []
self.last_local_log = stats.now

if stats.spec_decode_metrics is not None:
self._log_gauge(
self.metrics.gauge_spec_decode_draft_acceptance_rate,
stats.spec_decode_metrics.draft_acceptance_rate)
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
stats.spec_decode_metrics.system_efficiency)
self._log_counter(
self.metrics.counter_spec_decode_num_accepted_tokens,
stats.spec_decode_metrics.accepted_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_draft_tokens,
stats.spec_decode_metrics.draft_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_emitted_tokens,
stats.spec_decode_metrics.emitted_tokens)


class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
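Beyond the in-process checks in the tests above, the new metrics are exported by the Prometheus stat logger, so they can also be consumed from outside the engine. A hedged sketch of scraping them from a locally running OpenAI-compatible server (the URL, port, and the assumption that speculative decoding is enabled are illustrative, not part of this PR):

```python
import urllib.request

from prometheus_client.parser import text_string_to_metric_families

# Assumes a vLLM OpenAI-compatible server with speculative decoding enabled
# is listening locally; URL and port are illustrative.
METRICS_URL = "http://localhost:8000/metrics"

SPEC_DECODE_SAMPLES = {
    "vllm:spec_decode_draft_acceptance_rate",
    "vllm:spec_decode_efficiency",
    "vllm:spec_decode_num_accepted_tokens_total",
    "vllm:spec_decode_num_draft_tokens_total",
    "vllm:spec_decode_num_emitted_tokens_total",
}


def scrape_spec_decode_metrics() -> dict:
    text = urllib.request.urlopen(METRICS_URL).read().decode()
    values = {}
    for family in text_string_to_metric_families(text):
        for sample in family.samples:
            if sample.name in SPEC_DECODE_SAMPLES:
                values[sample.name] = sample.value
    return values


if __name__ == "__main__":
    print(scrape_spec_decode_metrics())
```

In Prometheus itself, an averaged acceptance rate can be derived from the two counters, e.g. `rate(vllm:spec_decode_num_accepted_tokens_total[5m]) / rate(vllm:spec_decode_num_draft_tokens_total[5m])`.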