from itertools import cycle
- from typing import List, Optional, Tuple
+ from typing import List, Optional, Sequence, Tuple, Union

import pytest

from vllm import LLM, SamplingParams
from vllm.model_executor.utils import set_random_seed
+ from vllm.sequence import PromptLogprobs, SampleLogprobs

from ...conftest import cleanup
- from ...models.utils import check_logprobs_close, check_outputs_equal
+ from ...models.utils import (TokensTextLogprobs,
+                              TokensTextLogprobsPromptLogprobs,
+                              check_logprobs_close, check_outputs_equal)
from ...utils import RemoteOpenAIServer

PROMPTS = [
@@ -81,45 +84,77 @@ def get_output_from_llm_generator(
    return tokens, token_ids, acceptance_rate


- def run_logprob_correctness_test(vllm_runner,
-                                  common_llm_kwargs,
-                                  per_test_common_llm_kwargs,
-                                  baseline_llm_kwargs,
-                                  test_llm_kwargs,
-                                  batch_size: int,
-                                  max_output_len: int,
-                                  seed: Optional[int] = 0,
-                                  temperature: float = 0.0,
-                                  logprobs: int = 1):
-     org_args = {
-         **common_llm_kwargs,
-         **per_test_common_llm_kwargs,
-         **baseline_llm_kwargs,
-     }
-
-     sd_args = {
-         **common_llm_kwargs,
-         **per_test_common_llm_kwargs,
-         **test_llm_kwargs,
-     }
-
-     prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
-
-     sampling_params = SamplingParams(temperature=temperature,
-                                      max_tokens=max_output_len,
-                                      seed=seed,
-                                      logprobs=logprobs)
-
-     with vllm_runner(**org_args) as vllm_model:
-         org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
-
-     with vllm_runner(**sd_args) as vllm_model:
-         sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
-
-     check_logprobs_close(outputs_0_lst=org_outputs,
-                          outputs_1_lst=sd_outputs,
-                          name_0="org",
-                          name_1="sd")
+ def check_logprobs_correctness(
+     spec_outputs: Sequence[Union[TokensTextLogprobs,
+                                  TokensTextLogprobsPromptLogprobs]],
+     baseline_outputs: Sequence[Union[TokensTextLogprobs,
+                                      TokensTextLogprobsPromptLogprobs]],
+     disable_logprobs: bool = False,
+ ):
+     """Compare sampled and prompt logprobs between baseline and spec decoding
+     """
+     if not disable_logprobs:
+         return check_logprobs_close(
+             outputs_0_lst=baseline_outputs,
+             outputs_1_lst=spec_outputs,
+             name_0="org",
+             name_1="sd",
+         )
+
+     # Check correctness when disable_logprobs == True
+     for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
+         # Check generated token logprobs.
+         spec_logprobs = spec_output[2]
+         baseline_logprobs = baseline_output[2]
+         _check_logprobs_when_output_disabled(spec_logprobs,
+                                              baseline_logprobs,
+                                              is_prompt_logprobs=False)
+
+         # Check prompt logprobs too, if they exist
+         if len(baseline_output) == 4:
+             assert len(spec_output) == 4
+             spec_prompt_logprobs = spec_output[3]
+             baseline_prompt_logprobs = baseline_output[3]
+             _check_logprobs_when_output_disabled(spec_prompt_logprobs,
+                                                  baseline_prompt_logprobs,
+                                                  is_prompt_logprobs=True)
+
+
+ def _check_logprobs_when_output_disabled(
+     spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
+     baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
+     is_prompt_logprobs: bool = False,
+ ):
+     # Prompt logprobs are optional
+     if is_prompt_logprobs and baseline_logprobs is None:
+         assert spec_logprobs is None
+         return
+
+     assert spec_logprobs is not None
+     assert baseline_logprobs is not None
+     assert len(spec_logprobs) == len(baseline_logprobs)
+
+     # For each generated position of the sequence.
+     for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
+             zip(spec_logprobs, baseline_logprobs)):
+
+         # First prompt logprob is expected to be None
+         if is_prompt_logprobs and baseline_pos_logprobs is None:
+             assert spec_pos_logprobs is None
+             assert pos == 0
+             continue
+
+         assert spec_pos_logprobs is not None
+         assert baseline_pos_logprobs is not None
+
+         # When disabled, the 1 logprob is returned with dummy values for the
+         # score and rank, but the token id should match the baseline model
+         assert len(spec_pos_logprobs) == 1
+         (spec_pos_logprob_token_id,
+          spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
+         assert spec_pos_logprob.rank == -1
+         assert spec_pos_logprob.logprob == 0.0
+         assert spec_pos_logprob_token_id in baseline_pos_logprobs


def run_equality_correctness_test(
@@ -135,7 +170,10 @@ def run_equality_correctness_test(
        disable_seed: bool = False,
        ignore_eos: bool = True,
        ensure_all_accepted: bool = False,
-         expected_acceptance_rate: Optional[float] = None):
+         expected_acceptance_rate: Optional[float] = None,
+         logprobs: Optional[int] = None,
+         prompt_logprobs: Optional[int] = None,
+         disable_logprobs: bool = False):

    org_args = {
        **common_llm_kwargs,
@@ -157,10 +195,12 @@ def run_equality_correctness_test(
    sampling_params = SamplingParams(temperature=temperature,
                                     max_tokens=max_output_len,
                                     seed=seed,
-                                      ignore_eos=ignore_eos)
+                                      ignore_eos=ignore_eos,
+                                      logprobs=logprobs,
+                                      prompt_logprobs=prompt_logprobs)

    with vllm_runner(**org_args) as vllm_model:
-         org_outputs = vllm_model.generate(prompts, sampling_params)
+         org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    with vllm_runner(**sd_args) as vllm_model:
        if ensure_all_accepted or expected_acceptance_rate is not None:
@@ -169,7 +209,7 @@ def run_equality_correctness_test(
                'prometheus']
            stat_logger.local_interval = -100

-         sd_outputs = vllm_model.generate(prompts, sampling_params)
+         sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

        if ensure_all_accepted or expected_acceptance_rate is not None:
            acceptance_rate = (stat_logger.metrics.
@@ -185,11 +225,16 @@ def run_equality_correctness_test(
    if expected_acceptance_rate is not None:
        assert acceptance_rate >= expected_acceptance_rate - 1e-2

-     check_outputs_equal(outputs_0_lst=org_outputs,
-                         outputs_1_lst=sd_outputs,
+     # Only pass token entries, not the logprobs
+     check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
+                         outputs_1_lst=[out[0:2] for out in sd_outputs],
                        name_0="org",
                        name_1="sd")

+     # Check logprobs if requested
+     if logprobs is not None or prompt_logprobs is not None:
+         check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
+

def run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
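
For orientation, here is a minimal sketch of how a test could drive the new logprobs path through `run_equality_correctness_test`. It is an illustration, not part of this diff: the leading parameter names are inferred from the removed `run_logprob_correctness_test` signature, the `vllm_runner` fixture comes from this conftest, and the model names and speculative settings are placeholder assumptions.

# Hypothetical usage sketch; model choices and kwargs are assumptions.
def test_spec_decode_with_logprobs(vllm_runner):
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs={"model": "JackFram/llama-68m"},
        per_test_common_llm_kwargs={},
        baseline_llm_kwargs={},
        test_llm_kwargs={
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": 3,
        },
        batch_size=2,
        max_output_len=32,
        # New in this diff: requesting logprobs makes the equality test
        # also run check_logprobs_correctness on both runs' outputs.
        logprobs=2,
        prompt_logprobs=2,
        disable_logprobs=False)

With `disable_logprobs=False` the spec-decode logprobs must closely match the baseline; with `disable_logprobs=True` only the token ids are checked, since the returned logprobs carry dummy rank/score values.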