Commit dc99053

wwl2755 authored
[V1][Spec Decode] Eagle unit tests (#17350)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
1 parent ebab1ac commit dc99053
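
This commit adds unit tests for the V1 EagleProposer, covering prepare_inputs, load_model, and propose, plus two comment-only additions in vllm/v1/spec_decode/eagle.py. The proposer is built with device='cuda' throughout, so the tests presumably need a CUDA-capable machine and access to the Llama 3.1 and EAGLE checkpoints referenced in the test file; running them with a plain "pytest tests/v1/spec_decode/test_eagle.py" invocation is an assumption here, not something stated in the commit.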

File tree

tests/v1/spec_decode/test_eagle.py
vllm/v1/spec_decode/eagle.py

2 files changed: +344 -0 lines changed

tests/v1/spec_decode/test_eagle.py

Lines changed: 340 additions & 0 deletions
@@ -0,0 +1,340 @@
# SPDX-License-Identifier: Apache-2.0

from unittest import mock

import pytest
import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                         VllmConfig)
from vllm.v1.spec_decode.eagle import EagleProposer

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


def _create_proposer(method: str, k: int) -> EagleProposer:
    model_config = ModelConfig(model=model_dir,
                               task="generate",
                               max_model_len=100,
                               tokenizer=model_dir,
                               tokenizer_mode="auto",
                               dtype="auto",
                               seed=None,
                               trust_remote_code=False)

    # Choose model directory based on method
    draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir

    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        model=draft_model_dir,
        method=method,
        num_speculative_tokens=k,
    )

    vllm_config = VllmConfig(model_config=model_config,
                             cache_config=CacheConfig(),
                             speculative_config=speculative_config,
                             device_config=DeviceConfig(device="cuda"),
                             parallel_config=ParallelConfig(),
                             load_config=LoadConfig(),
                             scheduler_config=SchedulerConfig())

    return EagleProposer(vllm_config=vllm_config, device='cuda')

def test_prepare_inputs():
    """
    cu_target_query_lens: [0, a, a + b, a + b + c]
    num_rejected_tokens: [n1, n2, n3]
    num_tokens_per_req: [a - n1, b - n2, c - n3]
    cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    token_indices: [0, 1, ..., a - n1 - 1,
                    a, a + 1, ..., a + b - n2 - 1,
                    a + b, a + b + 1, ..., a + b + c - n3 - 1]
    """
    device = torch.device('cuda')

    # a = 4, b = 7, c = 5
    # n1 = 1, n2 = 3, n3 = 2

    # Cumulative lengths: [0, 4, 11, 16]
    cu_target_query_lens = torch.tensor([0, 4, 11, 16],
                                        dtype=torch.int32,
                                        device=device)

    # Rejected tokens per request: [1, 3, 2]
    num_rejected_tokens = torch.tensor([1, 3, 2],
                                       dtype=torch.int32,
                                       device=device)

    # Expected calculations:
    # query_len_per_req = [4, 7, 5]
    # num_tokens_per_req = [3, 4, 3] (after subtracting rejected tokens)
    # Expected cumulative counts: [0, 3, 7, 10]
    expected_cu_num_tokens = torch.tensor([0, 3, 7, 10],
                                          dtype=torch.int32,
                                          device=device)

    # Expected token indices (mapped from original positions):
    # First request: indices 0, 1, 2 (keeping first 3 from positions 0-3)
    # Second request: indices 4, 5, 6, 7 (keeping first 4 from positions 4-10)
    # Third request: indices 11, 12, 13 (keeping first 3 from positions 11-15)
    expected_token_indices = torch.tensor(
        [
            0,
            1,
            2,  # First request: 3 tokens (4-1)
            4,
            5,
            6,
            7,  # Second request: 4 tokens (7-3)
            11,
            12,
            13  # Third request: 3 tokens (5-2)
        ],
        dtype=torch.int32,
        device=device)

    cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
        cu_target_query_lens, num_rejected_tokens)

    assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
    assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
    assert torch.equal(token_indices, expected_token_indices)

@pytest.mark.parametrize(
    "method,proposer_helper,draft_model_dir,target_attribute_path", [
        ("eagle", lambda k: _create_proposer("eagle", k), eagle_dir,
         ('lm_head', )),
        ("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir,
         ('model', 'embed_tokens')),
    ])
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
                    mock_registry, mock_get_layers, method, proposer_helper,
                    draft_model_dir, target_attribute_path):

    # Setup mock for model class
    mock_model_cls = mock.MagicMock()
    mock_registry.resolve_model_cls.return_value = (mock_model_cls,
                                                    "test_arch")

    # Create a real context manager for mocks
    class MockContextManager:

        def __init__(self):
            pass

        def __enter__(self):
            return None

        def __exit__(self, exc_type, exc_val, exc_tb):
            return False

    # Make the mocks return actual context manager objects
    mock_set_dtype.return_value = MockContextManager()
    mock_set_config.return_value = MockContextManager()

    # Setup mocks for attention layers
    target_attn_layers = {
        "target_attn_1": mock.MagicMock(),
        "target_attn_2": mock.MagicMock()
    }
    # Draft model has one extra attention layer compared to target model
    all_attn_layers = {
        **target_attn_layers, "draft_extra_attn": mock.MagicMock()
    }

    # Make mock_get_layers return different values for each call
    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]

    # Setup model loader mock
    mock_loader = mock.MagicMock()
    mock_get_loader.return_value = mock_loader

    # Setup model mock
    mock_model = mock.MagicMock()
    mock_model_cls.return_value = mock_model
    mock_model.to.return_value = mock_model

    # Configure mock to test the attribute sharing path
    if method == "eagle":
        # For eagle, test the lm_head path
        mock_model.load_weights.return_value = {
            "model.embed_tokens.weight": torch.zeros(1)
        }
    else:
        # For eagle3, test the embed_tokens path
        mock_model.load_weights.return_value = {}

    # Setup target model with the appropriate attributes
    target_model = mock.MagicMock()

    # Create the necessary attributes on the target model
    current_obj = target_model
    for i, attr in enumerate(target_attribute_path):
        if i == len(target_attribute_path) - 1:
            # Set the last attribute in the path to a MagicMock
            setattr(current_obj, attr, mock.MagicMock())
        else:
            # Create intermediate objects if needed
            setattr(current_obj, attr, mock.MagicMock())
            current_obj = getattr(current_obj, attr)

    # Create proposer using the helper function
    proposer = proposer_helper(k=8)

    # Call the method under test
    proposer.load_model(target_model)

    # Verify common interactions
    mock_get_loader.assert_called_once()
    mock_model_cls.assert_called_once()
    mock_model.to.assert_called_once()
    mock_model.load_weights.assert_called_once()

    # Verify the loader was called with the right config
    mock_get_loader.assert_called_once_with(proposer.vllm_config.load_config)

    # Verify the specific attribute sharing based on the method
    if method == "eagle":
        assert proposer.model.lm_head == target_model.lm_head
    else:
        assert proposer.model.model.embed_tokens == \
            target_model.model.embed_tokens

@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(num_speculative_tokens):
    # Use GPU device
    device = torch.device('cuda')

    # Setup test parameters
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100

    # Create proposer first so we can use its actual hidden_size
    proposer = _create_proposer("eagle", num_speculative_tokens)
    # Get the hidden_size from the proposer to ensure consistency
    hidden_size = proposer.hidden_size

    # Helper to create deterministic logits that will produce specific tokens
    def create_deterministic_logits(token_ids):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for i, token_id in enumerate(token_ids):
            logits[i, token_id] = 100.0
        return logits

    # We mock a model that returns deterministic logits
    # Sequence 1: 42, 43, 44, ...
    # Sequence 2: 60, 61, 62, ...
    base_token_ids = [42, 60]

    # Skip loading the model and replace it with a mock directly
    # Create the mock model with deterministic outputs
    model_mock = mock.MagicMock()

    # Setup for model forward calls
    forward_returns = []
    for i in range(num_speculative_tokens):
        if i == 0:
            # First call uses all tokens
            h_logits = torch.zeros(total_tokens, hidden_size, device=device)
            h_states = torch.zeros(total_tokens, hidden_size, device=device)
        else:
            # Subsequent calls use batch_size tokens
            h_logits = torch.zeros(batch_size, hidden_size, device=device)
            h_states = torch.zeros(batch_size, hidden_size, device=device)
        forward_returns.append((h_logits, h_states))

    # For single token case, we only need the first item;
    # for multi-token, we need the sequence
    if num_speculative_tokens == 1:
        model_mock.return_value = forward_returns[0]
    else:
        model_mock.side_effect = forward_returns

    # Setup for compute_logits calls
    logits_returns = []
    for i in range(num_speculative_tokens):
        # For each call, increment the base token IDs
        current_tokens = [base_id + i for base_id in base_token_ids]
        logits_returns.append(create_deterministic_logits(current_tokens))

    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = logits_returns[0]
    else:
        model_mock.compute_logits.side_effect = logits_returns

    # Assign the mock to the proposer
    proposer.model = model_mock

    # Create input tensors
    cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
                                 dtype=torch.int32,
                                 device=device)

    target_token_ids = torch.randint(0,
                                     vocab_size, (total_tokens, ),
                                     device=device)
    target_positions = torch.cat([
        torch.arange(seq_len_1, device=device),
        torch.arange(seq_len_2, device=device)
    ])
    target_hidden_states = torch.randn(total_tokens,
                                       hidden_size,
                                       device=device)
    target_slot_mapping = torch.randint(0,
                                        100, (total_tokens, ),
                                        device=device)
    next_token_ids = torch.randint(0,
                                   vocab_size, (batch_size, ),
                                   dtype=torch.int32,
                                   device=device)
    block_table = torch.randint(0, 10, (batch_size, 10), device=device)

    sampling_metadata = mock.MagicMock()

    # Call the method under test
    result = proposer.propose(target_token_ids=target_token_ids,
                              target_positions=target_positions,
                              target_hidden_states=target_hidden_states,
                              target_slot_mapping=target_slot_mapping,
                              next_token_ids=next_token_ids,
                              cu_num_tokens=cu_num_tokens,
                              block_table=block_table,
                              sampling_metadata=sampling_metadata)

    assert result.shape == (batch_size, num_speculative_tokens)

    # Create expected tokens based on our token pattern
    if num_speculative_tokens == 1:
        # Example for num_speculative_tokens=1:
        # [[42], [60]]
        expected_tokens = torch.tensor(
            [[base_token_ids[0]], [base_token_ids[1]]], device=device)
    else:
        # Example for num_speculative_tokens=3:
        # [[42, 43, 44], [60, 61, 62]]
        expected_tokens = torch.zeros((batch_size, num_speculative_tokens),
                                      dtype=torch.int64,
                                      device=device)
        for i in range(batch_size):
            for j in range(num_speculative_tokens):
                expected_tokens[i, j] = base_token_ids[i] + j

    # Verify all tokens match our expectations
    assert torch.equal(result, expected_tokens)
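
A note on the mock.patch stack in test_load_model above: stacked patch decorators hand their mocks to the test bottom-up, so the decorator nearest the function (set_current_vllm_config) becomes the first parameter (mock_set_config) and the topmost one (get_layers_from_vllm_config) becomes the last mock parameter (mock_get_layers). A minimal standalone sketch of that convention, patching arbitrary stand-in targets (os.getcwd and os.listdir) rather than anything from vLLM:

from unittest import mock

# Stacked patches apply bottom-up: the decorator nearest the function
# supplies the first mock argument (the same convention test_load_model uses).
@mock.patch('os.listdir')   # topmost decorator -> last mock argument
@mock.patch('os.getcwd')    # nearest the function -> first mock argument
def test_patch_order(mock_getcwd, mock_listdir):
    import os
    mock_getcwd.return_value = '/tmp'
    mock_listdir.return_value = []
    assert os.getcwd() == '/tmp'
    assert os.listdir('/tmp') == []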

vllm/v1/spec_decode/eagle.py

Lines changed: 4 additions & 0 deletions
@@ -223,6 +223,8 @@ def propose(
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
                                                None)
+
+            # TODO(wenlong): get more than one token for tree attention
             draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)

@@ -251,6 +253,8 @@ def prepare_inputs(
         # [a, b, c] -> [a - n1, b - n2, c - n3]
         num_tokens_per_req = query_len_per_req - num_rejected_tokens
 
+        # [a - n1, b - n2, c - n3] ->
+        # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
         cu_num_tokens = torch.empty_like(cu_target_query_lens)
         torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
         cu_num_tokens[0] = 0
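
The docstring of test_prepare_inputs above and the new comment added to prepare_inputs describe the same index arithmetic. As a sanity reference, here is a standalone, CPU-only sketch of that arithmetic; it illustrates what the test asserts rather than the vLLM implementation itself, and reference_prepare_inputs is a hypothetical helper name:

import torch

def reference_prepare_inputs(cu_target_query_lens: torch.Tensor,
                             num_rejected_tokens: torch.Tensor):
    # [0, a, a + b, a + b + c] -> [a, b, c]
    query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
    # [a, b, c] -> [a - n1, b - n2, c - n3]
    num_tokens_per_req = query_len_per_req - num_rejected_tokens
    # [a - n1, b - n2, c - n3] -> [0, a - n1, a + b - n1 - n2, ...]
    cu_num_tokens = torch.zeros_like(cu_target_query_lens)
    cu_num_tokens[1:] = torch.cumsum(num_tokens_per_req, dim=0)
    # Keep the first (query_len - rejected) positions of every request.
    token_indices = torch.cat([
        torch.arange(int(start), int(start) + int(n), dtype=torch.int32)
        for start, n in zip(cu_target_query_lens[:-1], num_tokens_per_req)
    ])
    return cu_num_tokens, token_indices

cu, idx = reference_prepare_inputs(torch.tensor([0, 4, 11, 16]),
                                   torch.tensor([1, 3, 2]))
print(cu.tolist())   # [0, 3, 7, 10]
print(idx.tolist())  # [0, 1, 2, 4, 5, 6, 7, 11, 12, 13]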
