-from typing import Dict, List
+from typing import Any, Dict, List, Optional

import pytest
from transformers import AutoTokenizer
@@ -139,6 +139,15 @@ def create_dummy_logprobs(
    } for token_id in complete_sequence_token_ids]


+def create_dummy_prompt_logprobs(
+        complete_sequence_token_ids: List[int]
+) -> List[Optional[Dict[int, Any]]]:
+    # logprob for the first prompt token is None.
+    logprobs: List[Optional[Dict[int, Any]]] = [None]
+    logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
+    return logprobs
+
+
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False])
@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,

@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", [True])
-def test_decode_prompt_logprobs(complete_sequence: str,
-                                complete_sequence_token_ids: List[int],
-                                detokenizer: Detokenizer,
-                                skip_special_tokens: bool):
+def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
+                                detokenizer: Detokenizer):
    """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
+    sampling_params = SamplingParams(skip_special_tokens=True,
                                     prompt_logprobs=1)

    # Run sequentially.
@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
                              seqs=[seq],
                              sampling_params=sampling_params,
                              arrival_time=0.0)
-    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
-    detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
-    decoded_prompt_logprobs = dummy_logprobs
+    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
+    detokenizer.decode_prompt_logprobs_inplace(seq_group,
+                                               dummy_logprobs,
+                                               position_offset=0)
+    # First logprob is None.
+    decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
+        1:]  # type: ignore

-    if skip_special_tokens:
-        # Text for logprobs for the chosen token should be the same as the
-        # prompt text. Note that this will only be true if we skip
-        # special tokens.
-        assert complete_sequence == "".join([
-            logprobs[token_id].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
-        ])
-        assert complete_sequence != "".join([
-            logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
-        ])
+    # decoded_prompt_logprobs doesn't contain the first token.
+    token_ids = complete_sequence_token_ids
+    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
+    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
+    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
+    text = text_full[len(text_first):]
+
+    # Text for logprobs for the chosen token should be the same as the
+    # prompt text. Note that the first logprob is None.
+    assert text == "".join([
+        logprobs[token_id].decoded_token
+        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
+    ])
+    assert text != "".join([
+        logprobs[token_id + 1].decoded_token
+        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
+    ])
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
+def test_decode_prompt_logprobs_chunked_prefill(
+        vllm_runner,
+        model,
+        chunked_prefill_token_size: int,
+        example_prompts,
+):
+    max_num_seqs = 256
+    enable_chunked_prefill = False
+    max_num_batched_tokens = None
+    if chunked_prefill_token_size != -1:
+        enable_chunked_prefill = True
+        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
+        max_num_batched_tokens = chunked_prefill_token_size
+
+    with vllm_runner(model,
+                     dtype="half",
+                     max_logprobs=5,
+                     gpu_memory_utilization=0.5,
+                     enable_chunked_prefill=enable_chunked_prefill,
+                     max_num_batched_tokens=max_num_batched_tokens,
+                     max_num_seqs=max_num_seqs) as vllm_model:
+
+        vllm_sampling_params = SamplingParams(max_tokens=10,
+                                              logprobs=5,
+                                              prompt_logprobs=5,
+                                              temperature=0.0)
+        vllm_results = vllm_model.model.generate(
+            example_prompts, sampling_params=vllm_sampling_params)
+
+        for idx, result in enumerate(vllm_results):
+            assert result.prompt_logprobs is not None
+            assert result.prompt_logprobs[0] is None
+
+            # Compare detokenized prompt token ids to the original prompt.
+            generated_string = ""
+            for (prompt_token,
+                 prompt_logprobs) in zip(result.prompt_token_ids[1:],
+                                         result.prompt_logprobs[1:]):
+                # prompt_logprobs is a dict of token_id -> logprob.
+                # Select the entry for the actual prompt token and append
+                # its decoded text to the reconstructed string.
+                generated_string += prompt_logprobs[prompt_token].decoded_token
+
+            assert generated_string == example_prompts[idx], (
+                "Detokenized prompt logprobs do not match original prompt")