From 08ce232dd983c78cf36e8252aec6cdd61c9c1fcb Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Mon, 8 Jul 2024 21:31:53 +0000 Subject: [PATCH 01/21] fix issue with logprobs --- vllm/engine/output_processor/single_step.py | 7 +++---- vllm/model_executor/layers/sampler.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index fa672e1feda..a16d1bd85d6 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -65,10 +65,9 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, self.detokenizer.decode_prompt_logprobs_inplace( seq_group, prompt_logprobs) if not seq_group.prompt_logprobs: - # The first prompt token's logprob is None because it doesn't - # have tokens that are precedent. - seq_group.prompt_logprobs = [None] - seq_group.prompt_logprobs.extend(prompt_logprobs) + seq_group.prompt_logprobs = prompt_logprobs + else: + seq_group.prompt_logprobs.extend(prompt_logprobs) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 6d00ea64f7c..df8ebce13f9 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -836,7 +836,7 @@ def _get_prompt_logprob_if_needed( # Find prompt logprobs prompt_logprobs: Optional[PromptLogprobs] = None if is_prompt and sampling_params.prompt_logprobs is not None: - prompt_logprobs = [] + prompt_logprobs = [None] num_logprobs = sampling_params.prompt_logprobs next_prompt_tokens = _get_next_prompt_tokens(seq_group) # Pre-select indexes and create a list. It is faster than calling .item From 99e9cd854e22e97b80b48d8a06e25e886c061c33 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 01:46:57 +0000 Subject: [PATCH 02/21] reimplement in a way that supports chunked pre --- tests/samplers/test_logprobs.py | 6 ++++-- vllm/engine/output_processor/single_step.py | 16 ++++++++++++---- vllm/model_executor/layers/sampler.py | 4 ++-- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 02a953da046..64b44028f8a 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -12,7 +12,8 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) +# @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) @pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size @pytest.mark.parametrize("detokenize", [True, False]) def test_get_prompt_logprobs( @@ -80,10 +81,11 @@ def test_get_prompt_logprobs( assert output_text == '' assert output_string_from_most_likely_tokens_lst == ([None] * max_tokens) - + # The first prompt logprob is always None assert result.prompt_logprobs[0] is None for prompt_logprobs in result.prompt_logprobs[1:]: + # If the prompt token is not included in the top X # logprob, it can return 1 more data assert (len(prompt_logprobs) == num_top_logprobs diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index a16d1bd85d6..23a6d08aa2b 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -60,14 +60,22 @@ def 
process_prompt_logprob(self, seq_group: SequenceGroup, assert len(outputs) == 1, ("Single step should only has 1 output.") output = outputs[0] prompt_logprobs = output.prompt_logprobs + + # If this is the first (or only) "chunk" of the prefill, we need + # to prepend None to the list of prompt logprobs. The reason for this + # is that for N prompt tokens, the Sampler will generate N-1 total + # prompt logprobs during prefill since the token at idx 0 will not + # have a logprob associated with it. if prompt_logprobs is not None: + if not seq_group.prompt_logprobs: + prompt_logprobs = [None] + prompt_logprobs + seq_group.prompt_logprobs = [] + if seq_group.sampling_params.detokenize and self.detokenizer: self.detokenizer.decode_prompt_logprobs_inplace( seq_group, prompt_logprobs) - if not seq_group.prompt_logprobs: - seq_group.prompt_logprobs = prompt_logprobs - else: - seq_group.prompt_logprobs.extend(prompt_logprobs) + + seq_group.prompt_logprobs.extend(prompt_logprobs) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index df8ebce13f9..c0d09c21378 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -835,8 +835,8 @@ def _get_prompt_logprob_if_needed( # Find prompt logprobs prompt_logprobs: Optional[PromptLogprobs] = None - if is_prompt and sampling_params.prompt_logprobs is not None: - prompt_logprobs = [None] + if is_prompt and sampling_params.prompt_logprobs is not None: + prompt_logprobs = [] num_logprobs = sampling_params.prompt_logprobs next_prompt_tokens = _get_next_prompt_tokens(seq_group) # Pre-select indexes and create a list. It is faster than calling .item From 01088a0ce70e6e0ac5eeb1c61b42e90d3df32f77 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 01:47:36 +0000 Subject: [PATCH 03/21] removed nits --- tests/samplers/test_logprobs.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 64b44028f8a..02a953da046 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -12,8 +12,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) -# @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) -@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) @pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size @pytest.mark.parametrize("detokenize", [True, False]) def test_get_prompt_logprobs( @@ -81,11 +80,10 @@ def test_get_prompt_logprobs( assert output_text == '' assert output_string_from_most_likely_tokens_lst == ([None] * max_tokens) - + # The first prompt logprob is always None assert result.prompt_logprobs[0] is None for prompt_logprobs in result.prompt_logprobs[1:]: - # If the prompt token is not included in the top X # logprob, it can return 1 more data assert (len(prompt_logprobs) == num_top_logprobs From a7a04aca5b0c594ef7963cb6c7f6b731fe0daa0c Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 01:53:44 +0000 Subject: [PATCH 04/21] nits --- vllm/engine/output_processor/single_step.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 23a6d08aa2b..4d030943474 100644 --- 
a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -74,7 +74,6 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, if seq_group.sampling_params.detokenize and self.detokenizer: self.detokenizer.decode_prompt_logprobs_inplace( seq_group, prompt_logprobs) - seq_group.prompt_logprobs.extend(prompt_logprobs) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, From a20be2384cbdb46b3f6f41c6c7109c35a24d7879 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 01:54:08 +0000 Subject: [PATCH 05/21] nit --- vllm/model_executor/layers/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c0d09c21378..6d00ea64f7c 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -835,7 +835,7 @@ def _get_prompt_logprob_if_needed( # Find prompt logprobs prompt_logprobs: Optional[PromptLogprobs] = None - if is_prompt and sampling_params.prompt_logprobs is not None: + if is_prompt and sampling_params.prompt_logprobs is not None: prompt_logprobs = [] num_logprobs = sampling_params.prompt_logprobs next_prompt_tokens = _get_next_prompt_tokens(seq_group) From bc40a992028a13e3505918b13665c0c86c483ae2 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 02:25:42 +0000 Subject: [PATCH 06/21] this test is almost passing --- tests/tokenization/test_detokenize.py | 82 ++++++++++++++++++++++++--- 1 file changed, 73 insertions(+), 9 deletions(-) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 12e5ae85ade..5f6b8ac644e 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -139,6 +139,13 @@ def create_dummy_logprobs( } for token_id in complete_sequence_token_ids] + +def create_dummy_prompt_logprobs( + complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]: + # logprob for the first prompt token is not defined. + return create_dummy_logprobs(complete_sequence_token_ids)[1:] + + @pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("skip_special_tokens", [True, False]) @@ -178,8 +185,7 @@ def test_decode_sequence_logprobs(complete_sequence: str, @pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS) @pytest.mark.parametrize("skip_special_tokens", [True]) -def test_decode_prompt_logprobs(complete_sequence: str, - complete_sequence_token_ids: List[int], +def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int], detokenizer: Detokenizer, skip_special_tokens: bool): """Verify Detokenizer decodes prompt logprobs correctly.""" @@ -192,19 +198,77 @@ def test_decode_prompt_logprobs(complete_sequence: str, seqs=[seq], sampling_params=sampling_params, arrival_time=0.0) - dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) + dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids) detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs) decoded_prompt_logprobs = dummy_logprobs if skip_special_tokens: + # decoded_prompt_logprobs doesn't contain the first token. 
+ token_ids = complete_sequence_token_ids[1:] + tokenzier = detokenizer.get_tokenizer_for_seq(seq) + text = tokenzier.decode(token_ids, + skip_special_tokens=skip_special_tokens) # Text for logprobs for the chosen token should be the same as the # prompt text. Note that this will only be true if we skip # special tokens. - assert complete_sequence == "".join([ - logprobs[token_id].decoded_token for token_id, logprobs in zip( - complete_sequence_token_ids, decoded_prompt_logprobs) + assert text == "".join([ + logprobs[token_id].decoded_token + for token_id, logprobs in zip(token_ids, decoded_prompt_logprobs) ]) - assert complete_sequence != "".join([ - logprobs[token_id + 1].decoded_token for token_id, logprobs in zip( - complete_sequence_token_ids, decoded_prompt_logprobs) + assert text != "".join([ + logprobs[token_id + 1].decoded_token + for token_id, logprobs in zip(token_ids, decoded_prompt_logprobs) ]) + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +# @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_decode_logprobs_regression( + vllm_runner, + model, + chunked_prefill_token_size: int, + example_prompts, +): + max_num_seqs = 256 + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) + max_num_batched_tokens = chunked_prefill_token_size + + with vllm_runner( + model, + dtype="half", + max_logprobs=5, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs) as vllm_model: + + vllm_sampling_params = SamplingParams(max_tokens=10, + logprobs=5, + prompt_logprobs=5, + temperature=0.0) + vllm_results = vllm_model.model.generate( + example_prompts, sampling_params=vllm_sampling_params) + + for idx, result in enumerate(vllm_results): + assert result.prompt_logprobs is not None + assert result.prompt_logprobs[0] is None + + # Compared detokenized prompts ids to original prompt. + generated_string = "" + for (prompt_token, prompt_logprobs) in zip( + result.prompt_token_ids[1:], result.prompt_logprobs[1:]): + # prompt_logprobs is a dict of the token_id: logprob + # We select the token_id corresponding to the actual prompt + # Decoded token in the detokenized string corresponding to this + # promptt oken. 
+ generated_string += prompt_logprobs[prompt_token].decoded_token + + breakpoint() + assert generated_string == example_prompts[idx], ( + "Detokenized prompt logprobs do not match original prompt" + ) + From 5f20d78e908258bb1343938602b34e34da910d79 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 02:40:21 +0000 Subject: [PATCH 07/21] tests passing --- tests/tokenization/test_detokenize.py | 4 +--- vllm/engine/output_processor/single_step.py | 4 +++- vllm/transformers_utils/detokenizer.py | 12 ++++++++++-- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 5f6b8ac644e..4d240c3e925 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -222,8 +222,7 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int], @pytest.mark.parametrize("model", ["facebook/opt-125m"]) -# @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) -@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1]) def test_decode_logprobs_regression( vllm_runner, model, @@ -267,7 +266,6 @@ def test_decode_logprobs_regression( # promptt oken. generated_string += prompt_logprobs[prompt_token].decoded_token - breakpoint() assert generated_string == example_prompts[idx], ( "Detokenized prompt logprobs do not match original prompt" ) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 4d030943474..b76734d2c35 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -73,7 +73,9 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, if seq_group.sampling_params.detokenize and self.detokenizer: self.detokenizer.decode_prompt_logprobs_inplace( - seq_group, prompt_logprobs) + seq_group, prompt_logprobs, + position_offset=len(seq_group.prompt_logprobs)) + seq_group.prompt_logprobs.extend(prompt_logprobs) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index e8e53f4946e..7f16b0f014f 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -23,12 +23,15 @@ def get_tokenizer_for_seq(self, def decode_prompt_logprobs_inplace( self, seq_group: SequenceGroup, - prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None: + prompt_logprobs: List[Optional[Dict[int, Logprob]]], + position_offset: int) -> None: """Decodes the logprobs for the prompt of a sequence group. Args: seq_group: The sequence group to decode. prompt_logprobs: The logprobs to decode. + position_offset: Offset of the first index of the logprobs + relative to the start of the sequence (for chunked prefill). Returns: The prompt logprobs with the decoded tokens. @@ -47,8 +50,13 @@ def decode_prompt_logprobs_inplace( next_iter_tokens: List[str] = [] prev_tokens = None - for token_position, prompt_logprobs_for_token in enumerate( + for token_position_in_logprob, prompt_logprobs_for_token in enumerate( prompt_logprobs): + + # Absolute token position equals the index in the logprobs + # list plus the offset of the entire logprobs list relative + # to the start of the sequence. 
+ token_position = token_position_in_logprob + position_offset if not prompt_logprobs_for_token: continue for token_id, sample_logprob in prompt_logprobs_for_token.items(): From 7fd14b494137e7c26e289f8b588efa09076b510e Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 02:40:37 +0000 Subject: [PATCH 08/21] format --- tests/tokenization/test_detokenize.py | 25 +++++++++------------ vllm/engine/output_processor/single_step.py | 11 ++++----- vllm/transformers_utils/detokenizer.py | 12 +++++----- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 4d240c3e925..65089ba4f78 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -139,7 +139,6 @@ def create_dummy_logprobs( } for token_id in complete_sequence_token_ids] - def create_dummy_prompt_logprobs( complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]: # logprob for the first prompt token is not defined. @@ -237,13 +236,12 @@ def test_decode_logprobs_regression( max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) max_num_batched_tokens = chunked_prefill_token_size - with vllm_runner( - model, - dtype="half", - max_logprobs=5, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs) as vllm_model: + with vllm_runner(model, + dtype="half", + max_logprobs=5, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs) as vllm_model: vllm_sampling_params = SamplingParams(max_tokens=10, logprobs=5, @@ -255,11 +253,12 @@ def test_decode_logprobs_regression( for idx, result in enumerate(vllm_results): assert result.prompt_logprobs is not None assert result.prompt_logprobs[0] is None - + # Compared detokenized prompts ids to original prompt. generated_string = "" - for (prompt_token, prompt_logprobs) in zip( - result.prompt_token_ids[1:], result.prompt_logprobs[1:]): + for (prompt_token, + prompt_logprobs) in zip(result.prompt_token_ids[1:], + result.prompt_logprobs[1:]): # prompt_logprobs is a dict of the token_id: logprob # We select the token_id corresponding to the actual prompt # Decoded token in the detokenized string corresponding to this @@ -267,6 +266,4 @@ def test_decode_logprobs_regression( generated_string += prompt_logprobs[prompt_token].decoded_token assert generated_string == example_prompts[idx], ( - "Detokenized prompt logprobs do not match original prompt" - ) - + "Detokenized prompt logprobs do not match original prompt") diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index b76734d2c35..4851897ddef 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -60,11 +60,11 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, assert len(outputs) == 1, ("Single step should only has 1 output.") output = outputs[0] prompt_logprobs = output.prompt_logprobs - + # If this is the first (or only) "chunk" of the prefill, we need # to prepend None to the list of prompt logprobs. The reason for this - # is that for N prompt tokens, the Sampler will generate N-1 total - # prompt logprobs during prefill since the token at idx 0 will not + # is that for N prompt tokens, the Sampler will generate N-1 total + # prompt logprobs during prefill since the token at idx 0 will not # have a logprob associated with it. 
if prompt_logprobs is not None: if not seq_group.prompt_logprobs: @@ -73,9 +73,10 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, if seq_group.sampling_params.detokenize and self.detokenizer: self.detokenizer.decode_prompt_logprobs_inplace( - seq_group, prompt_logprobs, + seq_group, + prompt_logprobs, position_offset=len(seq_group.prompt_logprobs)) - + seq_group.prompt_logprobs.extend(prompt_logprobs) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 7f16b0f014f..cc9a971301a 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -21,10 +21,10 @@ def get_tokenizer_for_seq(self, """Returns the HF tokenizer to use for a given sequence.""" return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request) - def decode_prompt_logprobs_inplace( - self, seq_group: SequenceGroup, - prompt_logprobs: List[Optional[Dict[int, Logprob]]], - position_offset: int) -> None: + def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, + prompt_logprobs: List[Optional[Dict[ + int, Logprob]]], + position_offset: int) -> None: """Decodes the logprobs for the prompt of a sequence group. Args: @@ -52,8 +52,8 @@ def decode_prompt_logprobs_inplace( for token_position_in_logprob, prompt_logprobs_for_token in enumerate( prompt_logprobs): - - # Absolute token position equals the index in the logprobs + + # Absolute token position equals the index in the logprobs # list plus the offset of the entire logprobs list relative # to the start of the sequence. token_position = token_position_in_logprob + position_offset From 6a3bd6380159d5d3d0f3601668daab7a061d0286 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 02:47:58 +0000 Subject: [PATCH 09/21] add co-author From badd8b5c0a61d73bc3a96a59e4915ea311648bc5 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 02:48:14 +0000 Subject: [PATCH 10/21] add co-author From 2c38423a1e794f7a84befd01faf68445420d2a88 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 02:49:45 +0000 Subject: [PATCH 11/21] add co-author Co-authored-by: Zifei Tong From af7a0da6f30ad1da316b23d30cd7bb36244dae7c Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Mon, 8 Jul 2024 22:57:15 -0400 Subject: [PATCH 12/21] Update test_detokenize.py --- tests/tokenization/test_detokenize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 65089ba4f78..1781060e0f9 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -262,7 +262,7 @@ def test_decode_logprobs_regression( # prompt_logprobs is a dict of the token_id: logprob # We select the token_id corresponding to the actual prompt # Decoded token in the detokenized string corresponding to this - # promptt oken. + # prompt token. 
generated_string += prompt_logprobs[prompt_token].decoded_token assert generated_string == example_prompts[idx], ( From 4f1f8ff6e97923af6619939c7f9c814fa1efba0f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 16:08:14 +0000 Subject: [PATCH 13/21] finally got tests passing --- tests/tokenization/test_detokenize.py | 56 +++++++++++++------------- vllm/transformers_utils/detokenizer.py | 3 +- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 65089ba4f78..997979350cb 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -141,8 +141,10 @@ def create_dummy_logprobs( def create_dummy_prompt_logprobs( complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]: - # logprob for the first prompt token is not defined. - return create_dummy_logprobs(complete_sequence_token_ids)[1:] + # logprob for the first prompt token is None. + logprobs = [None] + logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:]) + return logprobs @pytest.mark.parametrize("complete_sequence", TRUTH) @@ -183,12 +185,10 @@ def test_decode_sequence_logprobs(complete_sequence: str, @pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -@pytest.mark.parametrize("skip_special_tokens", [True]) def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int], - detokenizer: Detokenizer, - skip_special_tokens: bool): + detokenizer: Detokenizer): """Verify Detokenizer decodes prompt logprobs correctly.""" - sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, + sampling_params = SamplingParams(skip_special_tokens=True, prompt_logprobs=1) # Run sequentially. @@ -198,31 +198,33 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int], sampling_params=sampling_params, arrival_time=0.0) dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids) - detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs) - decoded_prompt_logprobs = dummy_logprobs - - if skip_special_tokens: - # decoded_prompt_logprobs doesn't contain the first token. - token_ids = complete_sequence_token_ids[1:] - tokenzier = detokenizer.get_tokenizer_for_seq(seq) - text = tokenzier.decode(token_ids, - skip_special_tokens=skip_special_tokens) - # Text for logprobs for the chosen token should be the same as the - # prompt text. Note that this will only be true if we skip - # special tokens. - assert text == "".join([ - logprobs[token_id].decoded_token - for token_id, logprobs in zip(token_ids, decoded_prompt_logprobs) - ]) - assert text != "".join([ - logprobs[token_id + 1].decoded_token - for token_id, logprobs in zip(token_ids, decoded_prompt_logprobs) - ]) + detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs, + position_offset=0) + # First logprob is None. + decoded_prompt_logprobs = dummy_logprobs[1:] + + # decoded_prompt_logprobs doesn't contain the first token. + token_ids = complete_sequence_token_ids + tokenzier = detokenizer.get_tokenizer_for_seq(seq) + text_full = tokenzier.decode(token_ids, skip_special_tokens=True) + text_first = tokenzier.decode(token_ids[0], skip_special_tokens=True) + text = text_full[len(text_first):] + + # Text for logprobs for the chosen token should be the same as the + # prompt text. Note that the first logprob is None. 
+ assert text == "".join([ + logprobs[token_id].decoded_token + for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs) + ]) + assert text != "".join([ + logprobs[token_id + 1].decoded_token + for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs) + ]) @pytest.mark.parametrize("model", ["facebook/opt-125m"]) @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1]) -def test_decode_logprobs_regression( +def test_decode_prompt_logprobs_chunked_prefill( vllm_runner, model, chunked_prefill_token_size: int, diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index cc9a971301a..955065ae48d 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -75,7 +75,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, spaces_between_special_tokens=prms. spaces_between_special_tokens, ) - + sample_logprob.decoded_token = new_text # Use the offsets & prev tokens corresponding to @@ -86,6 +86,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, next_iter_read_offset = new_read_offset next_iter_tokens = new_tokens + # Advance to the next token position. prefix_offset = next_iter_prefix_offset read_offset = next_iter_read_offset From dbe97475dbe8a1bf55ecebdbabaa1da0c6dddd44 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 16:10:38 +0000 Subject: [PATCH 14/21] nits --- vllm/transformers_utils/detokenizer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 955065ae48d..cc9a971301a 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -75,7 +75,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, spaces_between_special_tokens=prms. spaces_between_special_tokens, ) - + sample_logprob.decoded_token = new_text # Use the offsets & prev tokens corresponding to @@ -86,7 +86,6 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, next_iter_read_offset = new_read_offset next_iter_tokens = new_tokens - # Advance to the next token position. prefix_offset = next_iter_prefix_offset read_offset = next_iter_read_offset From 2fdbc01d479527f3dc9a5216100b2c6dabbd2f25 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 16:32:00 +0000 Subject: [PATCH 15/21] format --- tests/tokenization/test_detokenize.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 0cffebd4af8..6c8ff78bda5 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Any, Dict, List, Optional import pytest from transformers import AutoTokenizer @@ -140,9 +140,10 @@ def create_dummy_logprobs( def create_dummy_prompt_logprobs( - complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]: + complete_sequence_token_ids: List[int] +) -> List[Optional[Dict[int, Any]]]: # logprob for the first prompt token is None. 
- logprobs = [None] + logprobs: List[Optional[Dict[int, Any]]] = [None] logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:]) return logprobs @@ -198,10 +199,12 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int], sampling_params=sampling_params, arrival_time=0.0) dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids) - detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs, + detokenizer.decode_prompt_logprobs_inplace(seq_group, + dummy_logprobs, position_offset=0) # First logprob is None. - decoded_prompt_logprobs = dummy_logprobs[1:] + decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[ + 1:] # type: ignore # decoded_prompt_logprobs doesn't contain the first token. token_ids = complete_sequence_token_ids From 2084cf5e807212d9260dc12a674550099fe21965 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 9 Jul 2024 14:40:37 -0400 Subject: [PATCH 16/21] Update test_detokenize.py oom in automation --- tests/tokenization/test_detokenize.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 6c8ff78bda5..e06eece25c1 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -244,6 +244,7 @@ def test_decode_prompt_logprobs_chunked_prefill( with vllm_runner(model, dtype="half", max_logprobs=5, + gpu_memory_utilization=0.7, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, max_num_seqs=max_num_seqs) as vllm_model: From b2b003703acec9348276289457b00cebc9cf9831 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 9 Jul 2024 15:50:02 -0400 Subject: [PATCH 17/21] Update test_detokenize.py try wait for gpu memory to clear --- tests/tokenization/test_detokenize.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index e06eece25c1..70e25838abf 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional +import torch import pytest from transformers import AutoTokenizer @@ -8,6 +9,8 @@ detokenize_incrementally) from vllm.transformers_utils.tokenizer_group import get_tokenizer_group +from ..utils import wait_for_gpu_memory_to_clear + TRUTH = [ "Hello here, this is a simple test", "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. 
It is designed to be used in production environments, where inference and serving", # noqa @@ -233,6 +236,12 @@ def test_decode_prompt_logprobs_chunked_prefill( chunked_prefill_token_size: int, example_prompts, ): + wait_for_gpu_memory_to_clear( + devices=list(range(torch.cuda.device_count())), + threshold_bytes=2 * 2**30, + timeout_s=60, + ) + max_num_seqs = 256 enable_chunked_prefill = False max_num_batched_tokens = None @@ -240,11 +249,10 @@ def test_decode_prompt_logprobs_chunked_prefill( enable_chunked_prefill = True max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) max_num_batched_tokens = chunked_prefill_token_size - + with vllm_runner(model, dtype="half", max_logprobs=5, - gpu_memory_utilization=0.7, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, max_num_seqs=max_num_seqs) as vllm_model: From e52fd0bee612380f2fce31ac7a1482798b241a3a Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 9 Jul 2024 20:15:17 +0000 Subject: [PATCH 18/21] format --- tests/tokenization/test_detokenize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 70e25838abf..102f68fa4ba 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional -import torch import pytest +import torch from transformers import AutoTokenizer from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup @@ -249,7 +249,7 @@ def test_decode_prompt_logprobs_chunked_prefill( enable_chunked_prefill = True max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) max_num_batched_tokens = chunked_prefill_token_size - + with vllm_runner(model, dtype="half", max_logprobs=5, From a4f382f551aafe83f004dc03db35af6d87f09545 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 9 Jul 2024 17:22:12 -0400 Subject: [PATCH 19/21] Update test_detokenize.py try again --- tests/tokenization/test_detokenize.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 102f68fa4ba..f4551ed42ef 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Optional import pytest -import torch from transformers import AutoTokenizer from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup @@ -9,8 +8,6 @@ detokenize_incrementally) from vllm.transformers_utils.tokenizer_group import get_tokenizer_group -from ..utils import wait_for_gpu_memory_to_clear - TRUTH = [ "Hello here, this is a simple test", "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. 
It is designed to be used in production environments, where inference and serving", # noqa @@ -236,12 +233,6 @@ def test_decode_prompt_logprobs_chunked_prefill( chunked_prefill_token_size: int, example_prompts, ): - wait_for_gpu_memory_to_clear( - devices=list(range(torch.cuda.device_count())), - threshold_bytes=2 * 2**30, - timeout_s=60, - ) - max_num_seqs = 256 enable_chunked_prefill = False max_num_batched_tokens = None @@ -253,6 +244,7 @@ def test_decode_prompt_logprobs_chunked_prefill( with vllm_runner(model, dtype="half", max_logprobs=5, + gpu_memory_utilization=0.5, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, max_num_seqs=max_num_seqs) as vllm_model: From 97a8685539f5defbfd00899f31bfd7540d84c73e Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 11 Jul 2024 19:43:40 +0000 Subject: [PATCH 20/21] updated buildkite --- .buildkite/test-pipeline.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 8013fbb642b..d1185e4b1ac 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -87,7 +87,9 @@ steps: - label: Engine Test mirror_hardwares: [amd] - command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py + commands: + - pytest -v -s engine test_sequence.py test_config.py test_logger.py + - pytest -v -s tokenization - label: Entrypoints Test mirror_hardwares: [amd] From 7b44f2d7fa9697b586c14e58d98d323893c631f5 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Thu, 11 Jul 2024 16:00:34 -0400 Subject: [PATCH 21/21] Update test-pipeline.yaml --- .buildkite/test-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d1185e4b1ac..e09122ba61f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -89,6 +89,7 @@ steps: mirror_hardwares: [amd] commands: - pytest -v -s engine test_sequence.py test_config.py test_logger.py + # OOM in the CI unless we run this separately - pytest -v -s tokenization - label: Entrypoints Test