
Commit 7ed6a4f

[ BugFix ] Prompt Logprobs Detokenization (#6223)
Co-authored-by: Zifei Tong <zifeitong@gmail.com>
1 parent a4feba9 commit 7ed6a4f

4 files changed: +117 additions, -32 deletions


.buildkite/test-pipeline.yaml

Lines changed: 4 additions & 1 deletion
@@ -87,7 +87,10 @@ steps:
 
 - label: Engine Test
   mirror_hardwares: [amd]
-  command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
+  commands:
+  - pytest -v -s engine test_sequence.py test_config.py test_logger.py
+  # OOM in the CI unless we run this separately
+  - pytest -v -s tokenization
 
 - label: Entrypoints Test
   mirror_hardwares: [amd]

tests/tokenization/test_detokenize.py

Lines changed: 87 additions & 22 deletions
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Any, Dict, List, Optional
 
 import pytest
 from transformers import AutoTokenizer
@@ -139,6 +139,15 @@ def create_dummy_logprobs(
     } for token_id in complete_sequence_token_ids]
 
 
+def create_dummy_prompt_logprobs(
+        complete_sequence_token_ids: List[int]
+) -> List[Optional[Dict[int, Any]]]:
+    # logprob for the first prompt token is None.
+    logprobs: List[Optional[Dict[int, Any]]] = [None]
+    logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
+    return logprobs
+
+
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", [True, False])
@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,
 
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", [True])
-def test_decode_prompt_logprobs(complete_sequence: str,
-                                complete_sequence_token_ids: List[int],
-                                detokenizer: Detokenizer,
-                                skip_special_tokens: bool):
+def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
+                                detokenizer: Detokenizer):
     """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
+    sampling_params = SamplingParams(skip_special_tokens=True,
                                      prompt_logprobs=1)
 
     # Run sequentially.
@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
                               seqs=[seq],
                               sampling_params=sampling_params,
                               arrival_time=0.0)
-    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
-    detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
-    decoded_prompt_logprobs = dummy_logprobs
+    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
+    detokenizer.decode_prompt_logprobs_inplace(seq_group,
+                                               dummy_logprobs,
+                                               position_offset=0)
+    # First logprob is None.
+    decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
+        1:]  # type: ignore
 
-    if skip_special_tokens:
-        # Text for logprobs for the chosen token should be the same as the
-        # prompt text. Note that this will only be true if we skip
-        # special tokens.
-        assert complete_sequence == "".join([
-            logprobs[token_id].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
-        ])
-        assert complete_sequence != "".join([
-            logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
-        ])
+    # decoded_prompt_logprobs doesn't contain the first token.
+    token_ids = complete_sequence_token_ids
+    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
+    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
+    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
+    text = text_full[len(text_first):]
+
+    # Text for logprobs for the chosen token should be the same as the
+    # prompt text. Note that the first logprob is None.
+    assert text == "".join([
+        logprobs[token_id].decoded_token
+        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
+    ])
+    assert text != "".join([
+        logprobs[token_id + 1].decoded_token
+        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
+    ])
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
+def test_decode_prompt_logprobs_chunked_prefill(
+    vllm_runner,
+    model,
+    chunked_prefill_token_size: int,
+    example_prompts,
+):
+    max_num_seqs = 256
+    enable_chunked_prefill = False
+    max_num_batched_tokens = None
+    if chunked_prefill_token_size != -1:
+        enable_chunked_prefill = True
+        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
+        max_num_batched_tokens = chunked_prefill_token_size
+
+    with vllm_runner(model,
+                     dtype="half",
+                     max_logprobs=5,
+                     gpu_memory_utilization=0.5,
+                     enable_chunked_prefill=enable_chunked_prefill,
+                     max_num_batched_tokens=max_num_batched_tokens,
+                     max_num_seqs=max_num_seqs) as vllm_model:
+
+        vllm_sampling_params = SamplingParams(max_tokens=10,
+                                              logprobs=5,
+                                              prompt_logprobs=5,
+                                              temperature=0.0)
+        vllm_results = vllm_model.model.generate(
+            example_prompts, sampling_params=vllm_sampling_params)
+
+        for idx, result in enumerate(vllm_results):
+            assert result.prompt_logprobs is not None
+            assert result.prompt_logprobs[0] is None
+
+            # Compare the detokenized prompt token ids to the original prompt.
+            generated_string = ""
+            for (prompt_token,
+                 prompt_logprobs) in zip(result.prompt_token_ids[1:],
                                         result.prompt_logprobs[1:]):
+                # prompt_logprobs is a dict of token_id -> Logprob. We select
+                # the entry for the actual prompt token, whose decoded_token
+                # is the piece of the detokenized string corresponding to
+                # this prompt position.
+                generated_string += prompt_logprobs[prompt_token].decoded_token
+
+            assert generated_string == example_prompts[idx], (
+                "Detokenized prompt logprobs do not match original prompt")

vllm/engine/output_processor/single_step.py

Lines changed: 14 additions & 5 deletions
@@ -60,14 +60,23 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
         assert len(outputs) == 1, ("Single step should only have 1 output.")
         output = outputs[0]
         prompt_logprobs = output.prompt_logprobs
+
+        # If this is the first (or only) "chunk" of the prefill, we need
+        # to prepend None to the list of prompt logprobs. The reason for this
+        # is that for N prompt tokens, the Sampler will generate N-1 total
+        # prompt logprobs during prefill, since the token at idx 0 will not
+        # have a logprob associated with it.
         if prompt_logprobs is not None:
+            if not seq_group.prompt_logprobs:
+                prompt_logprobs = [None] + prompt_logprobs
+                seq_group.prompt_logprobs = []
+
             if seq_group.sampling_params.detokenize and self.detokenizer:
                 self.detokenizer.decode_prompt_logprobs_inplace(
-                    seq_group, prompt_logprobs)
-            if not seq_group.prompt_logprobs:
-                # The first prompt token's logprob is None because it doesn't
-                # have tokens that are precedent.
-                seq_group.prompt_logprobs = [None]
+                    seq_group,
+                    prompt_logprobs,
+                    position_offset=len(seq_group.prompt_logprobs))
+
             seq_group.prompt_logprobs.extend(prompt_logprobs)
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
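The comment block added in this hunk captures the core bookkeeping for chunked prefill: only the first chunk prepends None (the prompt token at index 0 never gets a logprob), and every later chunk is decoded at an offset equal to the number of prompt logprobs already accumulated. The standalone sketch below, with a hypothetical helper name and plain lists standing in for SequenceGroup.prompt_logprobs and the Detokenizer, illustrates that accumulation:

# Standalone sketch of the accumulation logic above; hypothetical helper,
# plain lists instead of the real SequenceGroup/Detokenizer objects.
from typing import Dict, List, Optional

def process_prompt_logprob_chunk(
        accumulated: List[Optional[Dict[int, float]]],
        chunk: List[Optional[Dict[int, float]]]) -> None:
    # First chunk of the prefill: prepend None, since for N prompt tokens
    # the Sampler emits only N-1 logprobs (index 0 has no predecessor).
    if not accumulated:
        chunk = [None] + chunk
    # A real implementation would decode here, passing
    # position_offset=len(accumulated) so token positions stay absolute.
    accumulated.extend(chunk)

prompt_logprobs: List[Optional[Dict[int, float]]] = []
# Prefill of a 5-token prompt split into chunks of 2, 2 and 1 tokens:
# the Sampler yields 1, 2 and 1 logprob entries respectively (4 = 5 - 1).
process_prompt_logprob_chunk(prompt_logprobs, [{11: -0.1}])
process_prompt_logprob_chunk(prompt_logprobs, [{12: -0.2}, {13: -0.3}])
process_prompt_logprob_chunk(prompt_logprobs, [{14: -0.4}])
assert prompt_logprobs[0] is None and len(prompt_logprobs) == 5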

vllm/transformers_utils/detokenizer.py

Lines changed: 12 additions & 4 deletions
@@ -21,14 +21,17 @@ def get_tokenizer_for_seq(self,
         """Returns the HF tokenizer to use for a given sequence."""
         return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
 
-    def decode_prompt_logprobs_inplace(
-            self, seq_group: SequenceGroup,
-            prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
+    def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
+                                       prompt_logprobs: List[Optional[Dict[
+                                           int, Logprob]]],
+                                       position_offset: int) -> None:
         """Decodes the logprobs for the prompt of a sequence group.
 
         Args:
             seq_group: The sequence group to decode.
             prompt_logprobs: The logprobs to decode.
+            position_offset: Offset of the first index of the logprobs
+                relative to the start of the sequence (for chunked prefill).
 
         Returns:
             The prompt logprobs with the decoded tokens.
@@ -47,8 +50,13 @@ def decode_prompt_logprobs_inplace(
         next_iter_tokens: List[str] = []
         prev_tokens = None
 
-        for token_position, prompt_logprobs_for_token in enumerate(
+        for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
                 prompt_logprobs):
+
+            # Absolute token position equals the index in the logprobs
+            # list plus the offset of the entire logprobs list relative
+            # to the start of the sequence.
+            token_position = token_position_in_logprob + position_offset
             if not prompt_logprobs_for_token:
                 continue
             for token_id, sample_logprob in prompt_logprobs_for_token.items():
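To make position_offset concrete, here is a small worked example (not from the commit) of the index arithmetic above; the token ids and logprob values are illustrative, and the slice merely stands in for whatever preceding-token context the incremental detokenizer consumes:

# Worked example (illustrative values, not from the diff): re-anchoring a
# chunk of prompt logprobs to absolute prompt positions via position_offset.
prompt_token_ids = [101, 17, 23, 42, 99, 7]   # illustrative prompt tokens
chunk = [{42: -0.5}, {99: -1.2}]              # logprobs for positions 3 and 4
position_offset = 3                           # this chunk starts at index 3

for i, logprobs_for_token in enumerate(chunk):
    token_position = i + position_offset
    # Incremental detokenization of this position can then look at every
    # prompt token before it, e.g. prompt_token_ids[:token_position].
    preceding = prompt_token_ids[:token_position]
    print(token_position, preceding)
# 3 [101, 17, 23]
# 4 [101, 17, 23, 42]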
