Skip to content

Commit 99484ae

Browse files
committed
adding acceptance rate test for large output length
1 parent 9587b05 commit 99484ae

File tree

2 files changed

+56
-9
lines changed

2 files changed

+56
-9
lines changed

tests/spec_decode/e2e/conftest.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
288288
ensure_all_accepted=ensure_all_accepted)
289289

290290

291-
def run_equality_correctness_test(baseline_llm_generator,
292-
test_llm_generator,
293-
batch_size,
294-
max_output_len,
295-
force_output_len: bool,
296-
temperature: float,
297-
seeded: bool,
298-
print_tokens: bool = False,
299-
ensure_all_accepted: bool = False):
291+
def run_equality_correctness_test(
292+
baseline_llm_generator,
293+
test_llm_generator,
294+
batch_size,
295+
max_output_len,
296+
force_output_len: bool,
297+
temperature: float,
298+
seeded: bool,
299+
print_tokens: bool = False,
300+
ensure_all_accepted: bool = False,
301+
expected_acceptance_rate: Optional[float] = None):
300302
"""Helper method that compares the outputs of both the baseline LLM and
301303
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
302304
the same when temperature is zero (or when temperature is > 0 and seeded).
@@ -359,3 +361,6 @@ def run_equality_correctness_test(baseline_llm_generator,
359361

360362
if ensure_all_accepted:
361363
assert acceptance_rate == 1.0
364+
365+
if expected_acceptance_rate is not None:
366+
assert acceptance_rate >= expected_acceptance_rate - 1e-2

tests/spec_decode/e2e/test_mlp_correctness.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,48 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
8282
force_output_len=True)
8383

8484

85+
@pytest.mark.parametrize(
86+
"common_llm_kwargs",
87+
[{
88+
# Skip cuda graph recording for fast test.
89+
"enforce_eager": True,
90+
91+
# Required for spec decode.
92+
"use_v2_block_manager": True,
93+
94+
# Print spec metrics.
95+
"disable_log_stats": False,
96+
97+
# Precision
98+
"dtype": PRECISION,
99+
100+
# Main model
101+
"model": MAIN_MODEL,
102+
}])
103+
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
104+
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
105+
@pytest.mark.parametrize("test_llm_kwargs", [
106+
{
107+
"speculative_model": SPEC_MODEL,
108+
},
109+
])
110+
@pytest.mark.parametrize("output_len", [2048])
111+
@pytest.mark.parametrize("batch_size", [1, 32])
112+
@pytest.mark.parametrize("seed", [1])
113+
def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
114+
batch_size: int, output_len: int):
115+
"""Verify acceptance rate with different batch size and large output
116+
length."""
117+
run_equality_correctness_test(baseline_llm_generator,
118+
test_llm_generator,
119+
batch_size,
120+
max_output_len=output_len,
121+
temperature=0.0,
122+
seeded=True,
123+
force_output_len=True,
124+
expected_acceptance_rate=0.6)
125+
126+
85127
@pytest.mark.parametrize(
86128
"common_llm_kwargs",
87129
[{

0 commit comments

Comments (0)