
Commit 01b6f9e

[Core][Bugfix] Support prompt_logprobs returned with speculative decoding (#8047)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
1 parent 13f9f7a commit 01b6f9e

14 files changed, +492 -134 lines changed
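
To show what this change enables at the API level, here is a minimal, hypothetical sketch (the model names below are placeholders, not taken from this commit): with speculative decoding configured, a request can now set prompt_logprobs in addition to logprobs and get both back.

from vllm import LLM, SamplingParams

# Hypothetical main/draft model pair; any compatible pair works.
llm = LLM(model="example-org/main-model",
          speculative_model="example-org/draft-model",
          num_speculative_tokens=3,
          use_v2_block_manager=True)

# Before this fix, prompt_logprobs were not returned correctly when
# speculative decoding was enabled; now both kinds of logprobs come back.
params = SamplingParams(max_tokens=32, logprobs=1, prompt_logprobs=1)

for out in llm.generate(["The capital of France is"], params):
    print(out.prompt_logprobs)        # logprob dicts for each prompt token
    print(out.outputs[0].logprobs)    # logprob dicts for each generated token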

tests/conftest.py

Lines changed: 1 addition & 3 deletions
@@ -675,8 +675,6 @@ def generate_w_logprobs(
         videos: Optional[PromptVideoInput] = None,
     ) -> Union[List[TokensTextLogprobs],
                List[TokensTextLogprobsPromptLogprobs]]:
-        assert sampling_params.logprobs is not None
-
         if images is not None:
             assert len(prompts) == len(images)

@@ -754,7 +752,7 @@ def generate_greedy_logprobs(
             temperature=0.0,
             max_tokens=max_tokens,
             logprobs=num_logprobs,
-            prompt_logprobs=(num_prompt_logprobs),
+            prompt_logprobs=num_prompt_logprobs,
             stop_token_ids=stop_token_ids)

         return self.generate_w_logprobs(prompts,

tests/spec_decode/e2e/conftest.py

Lines changed: 92 additions & 47 deletions
@@ -1,13 +1,16 @@
 from itertools import cycle
-from typing import List, Optional, Tuple
+from typing import List, Optional, Sequence, Tuple, Union

 import pytest

 from vllm import LLM, SamplingParams
 from vllm.model_executor.utils import set_random_seed
+from vllm.sequence import PromptLogprobs, SampleLogprobs

 from ...conftest import cleanup
-from ...models.utils import check_logprobs_close, check_outputs_equal
+from ...models.utils import (TokensTextLogprobs,
+                             TokensTextLogprobsPromptLogprobs,
+                             check_logprobs_close, check_outputs_equal)
 from ...utils import RemoteOpenAIServer

 PROMPTS = [
@@ -81,45 +84,77 @@ def get_output_from_llm_generator(
     return tokens, token_ids, acceptance_rate


-def run_logprob_correctness_test(vllm_runner,
-                                 common_llm_kwargs,
-                                 per_test_common_llm_kwargs,
-                                 baseline_llm_kwargs,
-                                 test_llm_kwargs,
-                                 batch_size: int,
-                                 max_output_len: int,
-                                 seed: Optional[int] = 0,
-                                 temperature: float = 0.0,
-                                 logprobs: int = 1):
-    org_args = {
-        **common_llm_kwargs,
-        **per_test_common_llm_kwargs,
-        **baseline_llm_kwargs,
-    }
-
-    sd_args = {
-        **common_llm_kwargs,
-        **per_test_common_llm_kwargs,
-        **test_llm_kwargs,
-    }
-
-    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
-
-    sampling_params = SamplingParams(temperature=temperature,
-                                     max_tokens=max_output_len,
-                                     seed=seed,
-                                     logprobs=logprobs)
-
-    with vllm_runner(**org_args) as vllm_model:
-        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
-
-    with vllm_runner(**sd_args) as vllm_model:
-        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
-
-    check_logprobs_close(outputs_0_lst=org_outputs,
-                         outputs_1_lst=sd_outputs,
-                         name_0="org",
-                         name_1="sd")
+def check_logprobs_correctness(
+    spec_outputs: Sequence[Union[TokensTextLogprobs,
+                                 TokensTextLogprobsPromptLogprobs]],
+    baseline_outputs: Sequence[Union[TokensTextLogprobs,
+                                     TokensTextLogprobsPromptLogprobs]],
+    disable_logprobs: bool = False,
+):
+    """Compare sampled and prompt logprobs between baseline and spec decoding
+    """
+    if not disable_logprobs:
+        return check_logprobs_close(
+            outputs_0_lst=baseline_outputs,
+            outputs_1_lst=spec_outputs,
+            name_0="org",
+            name_1="sd",
+        )
+
+    # Check correctness when disable_logprobs == True
+    for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
+        # Check generated token logprobs.
+        spec_logprobs = spec_output[2]
+        baseline_logprobs = baseline_output[2]
+        _check_logprobs_when_output_disabled(spec_logprobs,
+                                             baseline_logprobs,
+                                             is_prompt_logprobs=False)
+
+        # Check prompt logprobs too, if they exist
+        if len(baseline_output) == 4:
+            assert len(spec_output) == 4
+            spec_prompt_logprobs = spec_output[3]
+            baseline_prompt_logprobs = baseline_output[3]
+            _check_logprobs_when_output_disabled(spec_prompt_logprobs,
+                                                 baseline_prompt_logprobs,
+                                                 is_prompt_logprobs=True)
+
+
+def _check_logprobs_when_output_disabled(
+    spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
+    baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
+    is_prompt_logprobs: bool = False,
+):
+    # Prompt logprobs are optional
+    if is_prompt_logprobs and baseline_logprobs is None:
+        assert spec_logprobs is None
+        return
+
+    assert spec_logprobs is not None
+    assert baseline_logprobs is not None
+    assert len(spec_logprobs) == len(baseline_logprobs)
+
+    # For each generated position of the sequence.
+    for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
+            zip(spec_logprobs, baseline_logprobs)):
+
+        # First prompt logprob is expected to be None
+        if is_prompt_logprobs and baseline_pos_logprobs is None:
+            assert spec_pos_logprobs is None
+            assert pos == 0
+            continue
+
+        assert spec_pos_logprobs is not None
+        assert baseline_pos_logprobs is not None
+
+        # When disabled, the 1 logprob is returned with dummy values for the
+        # score and rank, but the token id should match the baseline model
+        assert len(spec_pos_logprobs) == 1
+        (spec_pos_logprob_token_id,
+         spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
+        assert spec_pos_logprob.rank == -1
+        assert spec_pos_logprob.logprob == 0.0
+        assert spec_pos_logprob_token_id in baseline_pos_logprobs


 def run_equality_correctness_test(
@@ -135,7 +170,10 @@ def run_equality_correctness_test(
         disable_seed: bool = False,
         ignore_eos: bool = True,
         ensure_all_accepted: bool = False,
-        expected_acceptance_rate: Optional[float] = None):
+        expected_acceptance_rate: Optional[float] = None,
+        logprobs: Optional[int] = None,
+        prompt_logprobs: Optional[int] = None,
+        disable_logprobs: bool = False):

     org_args = {
         **common_llm_kwargs,
@@ -157,10 +195,12 @@ def run_equality_correctness_test(
     sampling_params = SamplingParams(temperature=temperature,
                                      max_tokens=max_output_len,
                                      seed=seed,
-                                     ignore_eos=ignore_eos)
+                                     ignore_eos=ignore_eos,
+                                     logprobs=logprobs,
+                                     prompt_logprobs=prompt_logprobs)

     with vllm_runner(**org_args) as vllm_model:
-        org_outputs = vllm_model.generate(prompts, sampling_params)
+        org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

     with vllm_runner(**sd_args) as vllm_model:
         if ensure_all_accepted or expected_acceptance_rate is not None:
@@ -169,7 +209,7 @@
                 'prometheus']
            stat_logger.local_interval = -100

-        sd_outputs = vllm_model.generate(prompts, sampling_params)
+        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

        if ensure_all_accepted or expected_acceptance_rate is not None:
            acceptance_rate = (stat_logger.metrics.
@@ -185,11 +225,16 @@
            if expected_acceptance_rate is not None:
                assert acceptance_rate >= expected_acceptance_rate - 1e-2

-    check_outputs_equal(outputs_0_lst=org_outputs,
-                        outputs_1_lst=sd_outputs,
+    # Only pass token entries, not the logprobs
+    check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
+                        outputs_1_lst=[out[0:2] for out in sd_outputs],
                         name_0="org",
                         name_1="sd")

+    # Check logprobs if requested
+    if logprobs is not None or prompt_logprobs is not None:
+        check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
+

 def run_equality_correctness_test_tp(model,
                                      common_llm_kwargs,
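
One detail worth noting from the helper above: when disable_logprobs_during_spec_decoding is on, the spec-decode output still carries exactly one logprob entry per position, but the score and rank are dummy values and only the token id is guaranteed to be meaningful. A rough sketch of such an entry, assuming it is built with vllm.sequence.Logprob (the token id value here is a placeholder):

from vllm.sequence import Logprob

token_id = 42  # placeholder; the real id is whatever token was sampled
# One dict per sequence position, keyed by token id, mirroring the assertions
# in _check_logprobs_when_output_disabled above.
pos_logprobs = {token_id: Logprob(logprob=0.0, rank=-1)}
assert next(iter(pos_logprobs.values())).rank == -1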

tests/spec_decode/e2e/test_eagle_correctness.py

Lines changed: 58 additions & 0 deletions
@@ -80,6 +80,64 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                   batch_size, output_len, seed)


+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model_name": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "disable_logprobs_during_spec_decoding": False,
+    },
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+        "disable_logprobs_during_spec_decoding": True,
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    128,
+])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("logprobs", [1, 6])
+def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
+                                   per_test_common_llm_kwargs,
+                                   baseline_llm_kwargs, test_llm_kwargs,
+                                   batch_size: int, output_len: int, seed: int,
+                                   logprobs: int):
+
+    run_equality_correctness_test(vllm_runner,
+                                  common_llm_kwargs,
+                                  per_test_common_llm_kwargs,
+                                  baseline_llm_kwargs,
+                                  test_llm_kwargs,
+                                  batch_size,
+                                  output_len,
+                                  seed,
+                                  logprobs=logprobs,
+                                  prompt_logprobs=logprobs,
+                                  disable_logprobs=test_llm_kwargs[
+                                      'disable_logprobs_during_spec_decoding'])
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
