diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 8013fbb642b..e09122ba61f 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -87,7 +87,10 @@ steps:
 
 - label: Engine Test
   mirror_hardwares: [amd]
-  command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
+  commands:
+  - pytest -v -s engine test_sequence.py test_config.py test_logger.py
+  # OOM in the CI unless we run this separately
+  - pytest -v -s tokenization
 
 - label: Entrypoints Test
   mirror_hardwares: [amd]
diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py
index 12e5ae85ade..f4551ed42ef 100644
--- a/tests/tokenization/test_detokenize.py
+++ b/tests/tokenization/test_detokenize.py
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Any, Dict, List, Optional
 
 import pytest
 from transformers import AutoTokenizer
@@ -139,6 +139,15 @@ def create_dummy_logprobs(
     } for token_id in complete_sequence_token_ids]
 
 
+def create_dummy_prompt_logprobs(
+        complete_sequence_token_ids: List[int]
+) -> List[Optional[Dict[int, Any]]]:
+    # logprob for the first prompt token is None.
+    logprobs: List[Optional[Dict[int, Any]]] = [None]
+    logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
+    return logprobs
+
+
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", [True, False])
@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,
 
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", [True])
-def test_decode_prompt_logprobs(complete_sequence: str,
-                                complete_sequence_token_ids: List[int],
-                                detokenizer: Detokenizer,
-                                skip_special_tokens: bool):
+def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
+                                detokenizer: Detokenizer):
     """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
+    sampling_params = SamplingParams(skip_special_tokens=True,
                                      prompt_logprobs=1)
 
     # Run sequentially.
@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
                               seqs=[seq],
                               sampling_params=sampling_params,
                               arrival_time=0.0)
-    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
-    detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
-    decoded_prompt_logprobs = dummy_logprobs
+    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
+    detokenizer.decode_prompt_logprobs_inplace(seq_group,
+                                               dummy_logprobs,
+                                               position_offset=0)
+    # First logprob is None.
+    decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
+        1:]  # type: ignore
 
-    if skip_special_tokens:
-        # Text for logprobs for the chosen token should be the same as the
-        # prompt text. Note that this will only be true if we skip
-        # special tokens.
-        assert complete_sequence == "".join([
-            logprobs[token_id].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
-        ])
-        assert complete_sequence != "".join([
-            logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
-        ])
+    # decoded_prompt_logprobs doesn't contain the first token.
+    token_ids = complete_sequence_token_ids
+    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
+    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
+    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
+    text = text_full[len(text_first):]
+
+    # Text for logprobs for the chosen token should be the same as the
+    # prompt text. Note that the first logprob is None.
+    assert text == "".join([
+        logprobs[token_id].decoded_token
+        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
+    ])
+    assert text != "".join([
+        logprobs[token_id + 1].decoded_token
+        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
+    ])
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
+def test_decode_prompt_logprobs_chunked_prefill(
+        vllm_runner,
+        model,
+        chunked_prefill_token_size: int,
+        example_prompts,
+):
+    max_num_seqs = 256
+    enable_chunked_prefill = False
+    max_num_batched_tokens = None
+    if chunked_prefill_token_size != -1:
+        enable_chunked_prefill = True
+        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
+        max_num_batched_tokens = chunked_prefill_token_size
+
+    with vllm_runner(model,
+                     dtype="half",
+                     max_logprobs=5,
+                     gpu_memory_utilization=0.5,
+                     enable_chunked_prefill=enable_chunked_prefill,
+                     max_num_batched_tokens=max_num_batched_tokens,
+                     max_num_seqs=max_num_seqs) as vllm_model:
+
+        vllm_sampling_params = SamplingParams(max_tokens=10,
+                                              logprobs=5,
+                                              prompt_logprobs=5,
+                                              temperature=0.0)
+        vllm_results = vllm_model.model.generate(
+            example_prompts, sampling_params=vllm_sampling_params)
+
+    for idx, result in enumerate(vllm_results):
+        assert result.prompt_logprobs is not None
+        assert result.prompt_logprobs[0] is None
+
+        # Compare the detokenized prompt token ids to the original prompt.
+        generated_string = ""
+        for (prompt_token,
+             prompt_logprobs) in zip(result.prompt_token_ids[1:],
+                                     result.prompt_logprobs[1:]):
+            # prompt_logprobs is a dict mapping each candidate token_id to
+            # its Logprob. Select the entry for the actual prompt token and
+            # append its decoded text to rebuild the detokenized prompt
+            # string.
+            generated_string += prompt_logprobs[prompt_token].decoded_token
+
+        assert generated_string == example_prompts[idx], (
+            "Detokenized prompt logprobs do not match original prompt")
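
A minimal sketch (not part of the patch) of the prompt-logprob layout the tests above rely on: for N prompt tokens there are N entries, entry 0 is None, and every other entry maps candidate token ids to objects carrying a decoded_token string. FakeLogprob and rebuild_prompt_text below are hypothetical stand-ins for vllm.sequence.Logprob and for the join-over-chosen-tokens assertion used in the tests.

# Hypothetical sketch of the prompt-logprob layout assumed by the tests above.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class FakeLogprob:
    # Stand-in for vllm.sequence.Logprob; only decoded_token matters here.
    logprob: float
    decoded_token: Optional[str] = None


def rebuild_prompt_text(
        prompt_token_ids: List[int],
        prompt_logprobs: List[Optional[Dict[int, FakeLogprob]]]) -> str:
    # Position 0 is skipped: the first prompt token has nothing preceding it,
    # so it carries no logprob (and therefore no decoded text).
    return "".join(prompt_logprobs[pos][token_id].decoded_token
                   for pos, token_id in enumerate(prompt_token_ids)
                   if prompt_logprobs[pos] is not None)


prompt_ids = [10, 11, 12]
logprobs = [
    None,
    {11: FakeLogprob(-0.1, " world"), 99: FakeLogprob(-2.3, " there")},
    {12: FakeLogprob(-0.2, "!"), 98: FakeLogprob(-3.1, "?")},
]
assert rebuild_prompt_text(prompt_ids, logprobs) == " world!"
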
diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index fa672e1feda..4851897ddef 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -60,14 +60,23 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
         assert len(outputs) == 1, ("Single step should only has 1 output.")
         output = outputs[0]
         prompt_logprobs = output.prompt_logprobs
+
+        # If this is the first (or only) "chunk" of the prefill, we need
+        # to prepend None to the list of prompt logprobs. The reason for this
+        # is that for N prompt tokens, the Sampler will generate N-1 total
+        # prompt logprobs during prefill since the token at idx 0 will not
+        # have a logprob associated with it.
         if prompt_logprobs is not None:
+            if not seq_group.prompt_logprobs:
+                prompt_logprobs = [None] + prompt_logprobs
+                seq_group.prompt_logprobs = []
+
             if seq_group.sampling_params.detokenize and self.detokenizer:
                 self.detokenizer.decode_prompt_logprobs_inplace(
-                    seq_group, prompt_logprobs)
-            if not seq_group.prompt_logprobs:
-                # The first prompt token's logprob is None because it doesn't
-                # have tokens that are precedent.
-                seq_group.prompt_logprobs = [None]
+                    seq_group,
+                    prompt_logprobs,
+                    position_offset=len(seq_group.prompt_logprobs))
+            seq_group.prompt_logprobs.extend(prompt_logprobs)
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
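
To make the comment block above concrete, here is a standalone sketch (assumed names, not vLLM code) of the accumulation scheme process_prompt_logprob now implements: the first prefill chunk gets a leading None because the Sampler emits N-1 prompt logprobs for N prompt tokens, later chunks are appended as-is, and the detokenizer would be told how many entries already exist via position_offset.

# Standalone sketch of the chunked accumulation of prompt logprobs.
from typing import Dict, List, Optional

# Simplified chunk type: per-token dicts of {token_id: logprob}, or None.
Chunk = List[Optional[Dict[int, float]]]


def accumulate_prompt_logprobs(chunks: List[Chunk]) -> Chunk:
    accumulated: Chunk = []
    for chunk in chunks:
        if not accumulated:
            # First (or only) prefill chunk: prepend None because the token
            # at index 0 has no preceding token and thus no logprob.
            chunk = [None] + chunk
        # Here the real code would call
        # detokenizer.decode_prompt_logprobs_inplace(
        #     seq_group, chunk, position_offset=len(accumulated))
        accumulated.extend(chunk)
    return accumulated


# A 5-token prompt prefilled in two chunks (tokens 0-1, then 2-4): the Sampler
# yields 1 logprob for the first chunk and 3 for the second, 4 = N - 1 total.
first_chunk: Chunk = [{1: -0.5}]
second_chunk: Chunk = [{2: -0.1}, {3: -0.2}, {4: -0.3}]
result = accumulate_prompt_logprobs([first_chunk, second_chunk])
assert len(result) == 5 and result[0] is None
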
diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py
index e8e53f4946e..cc9a971301a 100644
--- a/vllm/transformers_utils/detokenizer.py
+++ b/vllm/transformers_utils/detokenizer.py
@@ -21,14 +21,17 @@ def get_tokenizer_for_seq(self,
         """Returns the HF tokenizer to use for a given sequence."""
         return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
 
-    def decode_prompt_logprobs_inplace(
-            self, seq_group: SequenceGroup,
-            prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
+    def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
+                                       prompt_logprobs: List[Optional[Dict[
+                                           int, Logprob]]],
+                                       position_offset: int) -> None:
         """Decodes the logprobs for the prompt of a sequence group.
 
         Args:
             seq_group: The sequence group to decode.
             prompt_logprobs: The logprobs to decode.
+            position_offset: Offset of the first index of the logprobs
+                relative to the start of the sequence (for chunked prefill).
 
         Returns:
             The prompt logprobs with the decoded tokens.
@@ -47,8 +50,13 @@ def decode_prompt_logprobs_inplace(
             next_iter_tokens: List[str] = []
             prev_tokens = None
 
-            for token_position, prompt_logprobs_for_token in enumerate(
+            for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
                     prompt_logprobs):
+
+                # Absolute token position equals the index in the logprobs
+                # list plus the offset of the entire logprobs list relative
+                # to the start of the sequence.
+                token_position = token_position_in_logprob + position_offset
                 if not prompt_logprobs_for_token:
                     continue
                 for token_id, sample_logprob in prompt_logprobs_for_token.items():
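
For completeness, a hedged end-to-end sketch of how the change surfaces to users; the model, prompt, and engine arguments are illustrative values mirroring the new test, not recommended settings. With chunked prefill enabled, prompt_logprobs still come back with one entry per prompt token, the first being None, and the decoded_token fields of the chosen tokens reassemble the prompt text (minus the first token).

# Illustrative usage only; mirrors test_decode_prompt_logprobs_chunked_prefill.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",
          enable_chunked_prefill=True,
          max_num_batched_tokens=16,
          max_num_seqs=16)
params = SamplingParams(max_tokens=1, prompt_logprobs=1, temperature=0.0)
output = llm.generate(["The capital of France is Paris."], params)[0]

assert output.prompt_logprobs is not None
assert output.prompt_logprobs[0] is None
# Rebuild the prompt (minus the first token's text) from the decoded logprobs.
detokenized = "".join(
    output.prompt_logprobs[pos][token_id].decoded_token
    for pos, token_id in enumerate(output.prompt_token_ids)
    if pos > 0)
print(detokenized)
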