
Commit c1e4a40

[V1][Spec Decode] Support multi-layer eagle draft model (#18030)
Signed-off-by: qizixi <qizixi@meta.com>
1 parent a859320 commit c1e4a40

File tree

3 files changed: +45 -9 lines changed


tests/v1/spec_decode/test_eagle.py

Lines changed: 3 additions & 0 deletions
@@ -246,6 +246,9 @@ def create_deterministic_logits(token_ids):
     # Assign the mock to the proposer
     proposer.model = model_mock
 
+    # Assign draft attn_layer_names since load_model is not invoked
+    proposer.attn_layer_names = ["layer.0"]
+
     # Create input tensors
     cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
                                  dtype=torch.int32,
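
The assignment above matters because propose() now iterates over self.attn_layer_names to build a per-layer attention metadata dict; with load_model() mocked out, the list would otherwise never be populated. A minimal sketch of that dependency (the layer name and the attn_metadata object are illustrative placeholders, not the test's real fixtures):

# Illustrative only: propose() maps every draft attention layer name to the
# same attention metadata object before calling the draft model.
attn_metadata = object()          # stand-in for a FlashAttentionMetadata
attn_layer_names = ["layer.0"]    # what the test assigns to the proposer
per_layer_attn_metadata = {name: attn_metadata for name in attn_layer_names}
assert per_layer_attn_metadata == {"layer.0": attn_metadata}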

vllm/v1/spec_decode/eagle.py

Lines changed: 29 additions & 4 deletions
@@ -12,6 +12,7 @@
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                    FlashAttentionMetadata)
+from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
 

@@ -150,6 +151,11 @@ def propose(
         else:
             raise ValueError(f"Unsupported method: {self.method}")
 
+        # At this moment, we assume all eagle layers belong to the same KV
+        # cache group, thus using the same attention metadata.
+        per_layer_attn_metadata = {}
+        for layer_name in self.attn_layer_names:
+            per_layer_attn_metadata[layer_name] = attn_metadata
         if self.use_cuda_graph and \
                 num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)

@@ -159,7 +165,7 @@ def propose(
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
 
-        with set_forward_context(attn_metadata,
+        with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
             ret_hidden_states = self.model(

@@ -245,7 +251,7 @@ def propose(
             self.hidden_states[:batch_size] = hidden_states
 
             # Run the model.
-            with set_forward_context(attn_metadata,
+            with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
                 last_hidden_states, hidden_states = self.model(

@@ -318,8 +324,8 @@ def load_model(self, target_model: nn.Module) -> None:
         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
             target_attn_layer_names)
-        assert len(draft_attn_layer_names) == 1
-        self.attn_layer_name = next(iter(draft_attn_layer_names))
+
+        self.attn_layer_names = list(draft_attn_layer_names)
 
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:

@@ -355,6 +361,25 @@ def dummy_run(
                 self.hidden_states[:num_tokens],
             )
 
+    def validate_same_kv_cache_group(self,
+                                     kv_cache_config: KVCacheConfig) -> None:
+        """
+        Validate that all eagle layers belong to the same KVCacheGroup.
+        Need this assumption to ensure all eagle layers can use the
+        same AttentionMetadata.
+        May extend to multiple AttentionMetadata in the future.
+        """
+        kv_cache_groups: dict[str, int] = {}
+        for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
+            for layer_name in kv_cache_group.layer_names:
+                kv_cache_groups[layer_name] = id
+        assert len(
+            set([
+                kv_cache_groups[layer_name]
+                for layer_name in self.attn_layer_names
+            ])
+        ) == 1, "All eagle layers should belong to the same kv cache group"
+
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
 # to sample the draft tokens. We will use this after we find a way to manage
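
For clarity, here is a self-contained sketch of the check that validate_same_kv_cache_group performs, using hypothetical stand-in classes instead of the real KVCacheConfig and its group entries from vllm.v1.kv_cache_interface (only the fields the new method actually reads are modeled):

from dataclasses import dataclass


@dataclass
class FakeKVCacheGroup:
    # Mirrors the only field the validation reads: layer_names.
    layer_names: list[str]


@dataclass
class FakeKVCacheConfig:
    kv_cache_groups: list[FakeKVCacheGroup]


def all_layers_in_one_group(attn_layer_names: list[str],
                            cfg: FakeKVCacheConfig) -> bool:
    # Map every layer name to the index of the KV cache group that owns it,
    # then require that all draft layers resolve to a single group id.
    layer_to_group = {
        name: gid
        for gid, group in enumerate(cfg.kv_cache_groups)
        for name in group.layer_names
    }
    return len({layer_to_group[name] for name in attn_layer_names}) == 1


cfg = FakeKVCacheConfig(kv_cache_groups=[
    FakeKVCacheGroup(layer_names=["draft.layer.0", "draft.layer.1"]),
    FakeKVCacheGroup(layer_names=["target.layer.0"]),
])
assert all_layers_in_one_group(["draft.layer.0", "draft.layer.1"], cfg)
assert not all_layers_in_one_group(["draft.layer.0", "target.layer.0"], cfg)

The real method asserts instead of returning a bool, so startup fails loudly if a draft model's layers were ever split across KV cache groups.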

vllm/v1/worker/gpu_model_runner.py

Lines changed: 13 additions & 5 deletions
@@ -1360,11 +1360,13 @@ def execute_model(
                     scheduler_output.num_scheduled_tokens[req_id])
                 next_token_id = req_state.get_token_id(seq_len)
                 next_token_ids.append(next_token_id)
-            next_token_ids = async_tensor_h2d(next_token_ids,
-                                              dtype=torch.int32,
-                                              target_device=self.device,
-                                              pin_memory=True)
-            eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
+            next_token_ids = torch.tensor(next_token_ids,
+                                          dtype=torch.int32,
+                                          device=self.device)
+            # At this moment, we assume all eagle layers belong to the same KV
+            # cache group, thus using the same attention metadata.
+            eagle_attn_metadata = attn_metadata[
+                self.drafter.attn_layer_names[0]]
 
             # NOTE: deepseek_mtp uses MLA which does not have `block_table`
             if hasattr(eagle_attn_metadata, "block_table"):

@@ -2018,6 +2020,12 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                 # KV cache specs.
                 raise ValueError("Unknown KV cache spec type.")
 
+        if self.speculative_config and self.speculative_config.use_eagle():
+            assert isinstance(self.drafter, EagleProposer)
+            # validate all draft model layers belong to the same kv cache
+            # group
+            self.drafter.validate_same_kv_cache_group(kv_cache_config)
+
         bind_kv_cache(
             kv_caches,
             self.vllm_config.compilation_config.static_forward_context,
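
A small illustration of why indexing with attn_layer_names[0] is sufficient here: under the same-KV-cache-group assumption enforced above, every draft layer name resolves to the same attention metadata object (layer names and the metadata stand-in below are made up for illustration):

# Illustrative only: attn_metadata maps layer name -> per-layer metadata.
shared = object()  # stand-in for the metadata the group shares
attn_metadata = {"draft.layer.0": shared, "draft.layer.1": shared}
attn_layer_names = ["draft.layer.0", "draft.layer.1"]

eagle_attn_metadata = attn_metadata[attn_layer_names[0]]
assert all(attn_metadata[name] is eagle_attn_metadata
           for name in attn_layer_names)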
