 Note: these tests will only pass on L4 GPU.
 """
 import os
-from typing import List
+from typing import Optional
 
 import pytest
-import torch
-from transformers import AutoTokenizer
 
+from tests.kernels.utils import override_backend_env_variable
 from tests.quantization.utils import is_quant_method_supported
-from vllm import LLM, SamplingParams
 
-os.environ["TOKENIZERS_PARALLELISM"] = "true"
-
-MAX_MODEL_LEN = 1024
-
-MODELS = [
-    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
-    "meta-llama/Meta-Llama-3-8B-Instruct",
-]
+from ..models.utils import check_logprobs_close
 
-EXPECTED_STRS_MAP = {
-    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": {
-        "auto": [
-            'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
-            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-            'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both',
-            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
-            'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
-            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
-            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no'
-        ],
-        "fp8": [
-            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
-            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-            'A neural network is a complex system made up of several basic components that work together to enable it to',
-            'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like',
-            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
-            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-            'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk'
-        ]
-    },
-    "meta-llama/Meta-Llama-3-8B-Instruct": {
-        "auto": [
-            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
-            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
-            'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
-            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
-            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-            'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
-        ],
-        "fp8": [
-            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
-            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-            'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
-            'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest',
-            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
-            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-            'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu'
-        ]
-    },
-}
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
 
-# This test compares against golden strings for exact match since
-# there is no baseline implementation to compare against
-# and is unstable w.r.t specifics of the fp8 implementation or
-# the hardware being run on.
-# Disabled to prevent it from breaking the build
-@pytest.mark.skip(
-    reason=
-    "Prevent unstable test based on golden strings from breaking the build.")
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="fp8 is not supported on this GPU type.")
-@pytest.mark.parametrize("model_name", MODELS)
-@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
-def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
-    model = LLM(model=model_name,
-                max_model_len=MAX_MODEL_LEN,
-                trust_remote_code=True,
-                enforce_eager=True,
-                quantization="fp8",
-                kv_cache_dtype=kv_cache_dtype)
+@pytest.mark.parametrize(
+    "kv_cache_dtype,base_model,test_model,scale_path",
+    [
+        # Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors.
+        ("fp8_e4m3", "meta-llama/Meta-Llama-3-8B-Instruct",
+         "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV", None),
+        # Test FP16 checkpoint w. fp8_e5m2 kv-cache.
+        ("fp8_e5m2", "meta-llama/Meta-Llama-3-8B-Instruct",
+         "meta-llama/Meta-Llama-3-8B-Instruct", None),
+        # Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json.
+        ("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf",
+         "meta-llama/Llama-2-7b-chat-hf",
+         "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
+    ])
+# Due to low-precision numerical divergence, we only test logprob of 4 tokens
+@pytest.mark.parametrize("max_tokens", [4])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
+# NOTE: Increasing this in this suite will fail CI because we currently cannot
+# reset distributed env properly. Use a value > 1 just when you test.
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+# Due to low-precision numerical divergence, this test is too sensitive for
+# the async postprocessor
+@pytest.mark.parametrize("disable_async_output_proc", [True])
+def test_models(
+    vllm_runner,
+    example_prompts,
+    kv_cache_dtype: str,
+    base_model: str,
+    test_model: str,
+    scale_path: Optional[str],
+    max_tokens: int,
+    enforce_eager: bool,
+    backend: str,
+    tensor_parallel_size: int,
+    disable_async_output_proc: bool,
+    monkeypatch,
+) -> None:
| 58 | + """ |
| 59 | + Only checks log probs match to cover the discrepancy in |
| 60 | + numerical sensitive kernels. |
| 61 | + """ |
+    override_backend_env_variable(monkeypatch, backend)
+
+    MAX_MODEL_LEN = 1024
+    NUM_LOG_PROBS = 8
+
+    with vllm_runner(
+            base_model,
+            max_model_len=MAX_MODEL_LEN,
+            tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=enforce_eager,
+            kv_cache_dtype="auto",
+            disable_async_output_proc=disable_async_output_proc,
+    ) as vllm_model:
+        baseline_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
 
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    formatted_prompts = [
-        tokenizer.apply_chat_template([{
-            "role": "user",
-            "content": prompt
-        }],
-                                      tokenize=False,
-                                      add_generation_prompt=True)
-        for prompt in example_prompts
-    ]
+    extra_kwargs = {}
+    if scale_path is not None:
+        extra_kwargs["quantization_param_path"] = scale_path
 
-    params = SamplingParams(max_tokens=20, temperature=0)
-    generations: List[str] = []
-    # Note: these need to be run 1 at a time due to numerical precision,
-    # since the expected strs were generated this way.
-    for prompt in formatted_prompts:
-        outputs = model.generate(prompt, params)
-        generations.append(outputs[0].outputs[0].text)
-    del model
+    with vllm_runner(
+            test_model,
+            max_model_len=MAX_MODEL_LEN,
+            tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=enforce_eager,
+            kv_cache_dtype=kv_cache_dtype,
+            disable_async_output_proc=disable_async_output_proc,
+            **extra_kwargs,
+    ) as vllm_model:
+        test_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
 
-    print(model_name, kv_cache_dtype, generations)
-    expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
-    for i in range(len(example_prompts)):
-        generated_str = generations[i]
-        expected_str = expected_strs[i]
-        assert expected_str == generated_str, (
-            f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")
+    check_logprobs_close(
+        outputs_0_lst=baseline_outputs,
+        outputs_1_lst=test_outputs,
+        name_0="fp16_kv_cache",
+        name_1="fp8_kv_cache",
+    )
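
The comparison at the end relies on check_logprobs_close from the shared model-test utilities: instead of the old golden-string exact match (removed above as unstable across fp8 implementations and hardware), the fp8 kv-cache run only has to stay within the fp16 baseline's top-k logprobs. A minimal sketch of that idea, assuming each per-prompt output is a (token_ids, text, logprobs) tuple with one top-k logprob dict per generated position; names and structure here are illustrative, not the exact vLLM helper:

# Hedged sketch, not the actual helper in tests/models/utils.py.
def check_logprobs_close_sketch(outputs_0_lst, outputs_1_lst, name_0, name_1):
    for prompt_idx, (out_0, out_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
        token_ids_0, _, logprobs_0 = out_0
        token_ids_1, _, logprobs_1 = out_1
        for pos, (tok_0, tok_1) in enumerate(zip(token_ids_0, token_ids_1)):
            if tok_0 == tok_1:
                continue
            # A mismatch is tolerated only if each run's token still appears in
            # the other run's top-k logprobs at that position.
            assert tok_0 in logprobs_1[pos] and tok_1 in logprobs_0[pos], (
                f"Prompt {prompt_idx}, position {pos}: "
                f"{name_0} produced {tok_0}, {name_1} produced {tok_1}")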