# SPDX-License-Identifier: Apache-2.0

from unittest import mock

import pytest
import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                         VllmConfig)
from vllm.v1.spec_decode.eagle import EagleProposer

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
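
# Note: these tests place the proposer and every tensor on "cuda", so a
# CUDA-capable device is assumed to be available.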


def _create_proposer(method: str, k: int) -> EagleProposer:
    model_config = ModelConfig(model=model_dir,
                               task="generate",
                               max_model_len=100,
                               tokenizer=model_dir,
                               tokenizer_mode="auto",
                               dtype="auto",
                               seed=None,
                               trust_remote_code=False)

    # Choose the draft model directory based on the method
    draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir

    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        model=draft_model_dir,
        method=method,
        num_speculative_tokens=k,
    )

    vllm_config = VllmConfig(model_config=model_config,
                             cache_config=CacheConfig(),
                             speculative_config=speculative_config,
                             device_config=DeviceConfig(device="cuda"),
                             parallel_config=ParallelConfig(),
                             load_config=LoadConfig(),
                             scheduler_config=SchedulerConfig())

    return EagleProposer(vllm_config=vllm_config, device="cuda")


def test_prepare_inputs():
    """
    cu_target_query_lens: [0, a, a + b, a + b + c]
    num_rejected_tokens: [n1, n2, n3]
    num_tokens_per_req: [a - n1, b - n2, c - n3]
    cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    token_indices: [0, 1, ..., a - n1 - 1,
                    a, a + 1, ..., a + b - n2 - 1,
                    a + b, a + b + 1, ..., a + b + c - n3 - 1]
    (A pure-Python sketch of this mapping follows the test.)
    """
    device = torch.device("cuda")

    # a = 4, b = 7, c = 5
    # n1 = 1, n2 = 3, n3 = 2

    # Cumulative query lengths: [0, 4, 11, 16]
    cu_target_query_lens = torch.tensor([0, 4, 11, 16],
                                        dtype=torch.int32,
                                        device=device)

    # Rejected tokens per request: [1, 3, 2]
    num_rejected_tokens = torch.tensor([1, 3, 2],
                                       dtype=torch.int32,
                                       device=device)

    # Expected calculations:
    # query_len_per_req = [4, 7, 5]
    # num_tokens_per_req = [3, 4, 3] (after subtracting rejected tokens)
    # Expected cumulative counts: [0, 3, 7, 10]
    expected_cu_num_tokens = torch.tensor([0, 3, 7, 10],
                                          dtype=torch.int32,
                                          device=device)

    # Expected token indices (mapped from original positions):
    # First request: indices 0, 1, 2 (keeping the first 3 of positions 0-3)
    # Second request: indices 4, 5, 6, 7 (keeping the first 4 of positions 4-10)
    # Third request: indices 11, 12, 13 (keeping the first 3 of positions 11-15)
    expected_token_indices = torch.tensor(
        [
            0, 1, 2,  # First request: 3 tokens (4 - 1)
            4, 5, 6, 7,  # Second request: 4 tokens (7 - 3)
            11, 12, 13  # Third request: 3 tokens (5 - 2)
        ],
        dtype=torch.int32,
        device=device)

    cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
        cu_target_query_lens, num_rejected_tokens)

    assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
    assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
    assert torch.equal(token_indices, expected_token_indices)
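

# A minimal host-side sketch of the mapping that prepare_inputs is expected
# to perform, assuming the semantics described in the docstring above. The
# helper is illustrative only and is not part of the vLLM API.
def _reference_prepare_inputs(cu_query_lens, rejected):
    cu_num_tokens = [0]
    token_indices = []
    for req, n_rejected in enumerate(rejected):
        start, end = cu_query_lens[req], cu_query_lens[req + 1]
        # Keep the first (query_len - rejected) tokens of each request
        kept = (end - start) - n_rejected
        token_indices.extend(range(start, start + kept))
        cu_num_tokens.append(cu_num_tokens[-1] + kept)
    return cu_num_tokens, token_indices


# e.g. _reference_prepare_inputs([0, 4, 11, 16], [1, 3, 2]) returns
# ([0, 3, 7, 10], [0, 1, 2, 4, 5, 6, 7, 11, 12, 13]), matching the
# expected tensors asserted above.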


@pytest.mark.parametrize(
    "method,proposer_helper,draft_model_dir,target_attribute_path", [
        ("eagle", lambda k: _create_proposer("eagle", k), eagle_dir,
         ('lm_head', )),
        ("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir,
         ('model', 'embed_tokens')),
    ])
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
                    mock_registry, mock_get_layers, method, proposer_helper,
                    draft_model_dir, target_attribute_path):
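    # target_attribute_path names the attribute chain that the draft model is
    # expected to share with the target model: EAGLE reuses the target's
    # lm_head, while EAGLE-3 reuses model.embed_tokens (see the assertions at
    # the end of this test).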

    # Set up the mock for the model class
    mock_model_cls = mock.MagicMock()
    mock_registry.resolve_model_cls.return_value = (mock_model_cls,
                                                    "test_arch")

    # A real context manager for the mocked config/dtype helpers
    class MockContextManager:

        def __enter__(self):
            return None

        def __exit__(self, exc_type, exc_val, exc_tb):
            return False

    # Make the mocks return actual context manager objects
    mock_set_dtype.return_value = MockContextManager()
    mock_set_config.return_value = MockContextManager()

    # Set up mocks for the attention layers
    target_attn_layers = {
        "target_attn_1": mock.MagicMock(),
        "target_attn_2": mock.MagicMock()
    }
    # The draft model has one extra attention layer compared to the target
    all_attn_layers = {
        **target_attn_layers, "draft_extra_attn": mock.MagicMock()
    }

    # Make mock_get_layers return different values for each call
    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]

    # Set up the model loader mock
    mock_loader = mock.MagicMock()
    mock_get_loader.return_value = mock_loader

    # Set up the model mock
    mock_model = mock.MagicMock()
    mock_model_cls.return_value = mock_model
    mock_model.to.return_value = mock_model

    # Configure the mock to exercise the attribute-sharing path
    if method == "eagle":
        # For eagle, test the lm_head path
        mock_model.load_weights.return_value = {
            "model.embed_tokens.weight": torch.zeros(1)
        }
    else:
        # For eagle3, test the embed_tokens path
        mock_model.load_weights.return_value = {}

    # Set up the target model with the appropriate attributes
    target_model = mock.MagicMock()

    # Create the nested attributes named by target_attribute_path
    current_obj = target_model
    for attr in target_attribute_path:
        setattr(current_obj, attr, mock.MagicMock())
        current_obj = getattr(current_obj, attr)

    # Create the proposer using the helper function
    proposer = proposer_helper(k=8)

    # Call the method under test
    proposer.load_model(target_model)

    # Verify common interactions
    mock_get_loader.assert_called_once()
    mock_model_cls.assert_called_once()
    mock_model.to.assert_called_once()
    mock_model.load_weights.assert_called_once()

    # Verify the loader was called with the right config
    mock_get_loader.assert_called_once_with(proposer.vllm_config.load_config)

    # Verify the attribute sharing specific to each method
    if method == "eagle":
        assert proposer.model.lm_head == target_model.lm_head
    else:
        assert proposer.model.model.embed_tokens == \
            target_model.model.embed_tokens


@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(num_speculative_tokens):
    # Use the GPU device
    device = torch.device("cuda")

    # Set up test parameters
    batch_size = 2
    seq_len_1 = 5
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100

    # Create the proposer first so we can use its actual hidden_size
    proposer = _create_proposer("eagle", num_speculative_tokens)
    # Get the hidden_size from the proposer to ensure consistency
    hidden_size = proposer.hidden_size

    # Helper to create deterministic logits that produce specific tokens
    def create_deterministic_logits(token_ids):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        for i, token_id in enumerate(token_ids):
            logits[i, token_id] = 100.0
        return logits

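    # For example, create_deterministic_logits([42, 60]) yields logits whose
    # row-wise argmax is 42 for sequence 1 and 60 for sequence 2, so greedy
    # sampling picks exactly those tokens.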
    # We mock a model that returns deterministic logits:
    # Sequence 1: 42, 43, 44, ...
    # Sequence 2: 60, 61, 62, ...
    base_token_ids = [42, 60]

    # Skip loading the model and replace it with a mock that has
    # deterministic outputs
    model_mock = mock.MagicMock()

    # Set up the return values for the model's forward calls
    forward_returns = []
    for i in range(num_speculative_tokens):
        if i == 0:
            # The first call consumes all target tokens
            h_logits = torch.zeros(total_tokens, hidden_size, device=device)
            h_states = torch.zeros(total_tokens, hidden_size, device=device)
        else:
            # Subsequent calls consume one token per request
            h_logits = torch.zeros(batch_size, hidden_size, device=device)
            h_states = torch.zeros(batch_size, hidden_size, device=device)
        forward_returns.append((h_logits, h_states))

    # For the single-token case we only need the first item;
    # for multi-token we need the whole sequence
    if num_speculative_tokens == 1:
        model_mock.return_value = forward_returns[0]
    else:
        model_mock.side_effect = forward_returns

    # Set up the return values for compute_logits calls
    logits_returns = []
    for i in range(num_speculative_tokens):
        # For each call, increment the base token IDs
        current_tokens = [base_id + i for base_id in base_token_ids]
        logits_returns.append(create_deterministic_logits(current_tokens))

    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = logits_returns[0]
    else:
        model_mock.compute_logits.side_effect = logits_returns

    # Assign the mock to the proposer
    proposer.model = model_mock

    # Create the input tensors
    cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
                                 dtype=torch.int32,
                                 device=device)

    target_token_ids = torch.randint(0,
                                     vocab_size, (total_tokens, ),
                                     device=device)
    target_positions = torch.cat([
        torch.arange(seq_len_1, device=device),
        torch.arange(seq_len_2, device=device)
    ])
    target_hidden_states = torch.randn(total_tokens,
                                       hidden_size,
                                       device=device)
    target_slot_mapping = torch.randint(0,
                                        100, (total_tokens, ),
                                        device=device)
    next_token_ids = torch.randint(0,
                                   vocab_size, (batch_size, ),
                                   dtype=torch.int32,
                                   device=device)
    block_table = torch.randint(0, 10, (batch_size, 10), device=device)

    sampling_metadata = mock.MagicMock()

    # Call the method under test
    result = proposer.propose(target_token_ids=target_token_ids,
                              target_positions=target_positions,
                              target_hidden_states=target_hidden_states,
                              target_slot_mapping=target_slot_mapping,
                              next_token_ids=next_token_ids,
                              cu_num_tokens=cu_num_tokens,
                              block_table=block_table,
                              sampling_metadata=sampling_metadata)

    assert result.shape == (batch_size, num_speculative_tokens)

    # Build the expected tokens from our deterministic token pattern
    if num_speculative_tokens == 1:
        # e.g. for num_speculative_tokens=1: [[42], [60]]
        expected_tokens = torch.tensor(
            [[base_token_ids[0]], [base_token_ids[1]]], device=device)
    else:
        # e.g. for num_speculative_tokens=3: [[42, 43, 44], [60, 61, 62]]
        expected_tokens = torch.zeros((batch_size, num_speculative_tokens),
                                      dtype=torch.int64,
                                      device=device)
        for i in range(batch_size):
            for j in range(num_speculative_tokens):
                expected_tokens[i, j] = base_token_ids[i] + j

    # Verify that all tokens match our expectations
    assert torch.equal(result, expected_tokens)