from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                    FlashAttentionMetadata)
+from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
@@ -150,6 +151,11 @@ def propose(
        else:
            raise ValueError(f"Unsupported method: {self.method}")

+        # At this moment, we assume all eagle layers belong to the same KV
+        # cache group, thus using the same attention metadata.
+        per_layer_attn_metadata = {}
+        for layer_name in self.attn_layer_names:
+            per_layer_attn_metadata[layer_name] = attn_metadata
        if self.use_cuda_graph and \
            num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
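The added block swaps the single shared `attn_metadata` for a dict keyed by draft attention layer name, so each EAGLE layer can look up its own metadata from the forward context. A minimal sketch of the resulting structure, using hypothetical layer names and a stand-in metadata object (neither comes from this PR):

```python
# Toy sketch (not vLLM code): hypothetical draft-layer names and a
# stand-in object for the FlashAttentionMetadata built above.
attn_layer_names = [
    "draft_model.layers.0.self_attn.attn",  # hypothetical name
    "draft_model.layers.1.self_attn.attn",  # hypothetical name
]
attn_metadata = object()  # stands in for the metadata built above

# Every draft layer points at the same metadata object, which is only
# valid if all of them share one KV cache group.
per_layer_attn_metadata = {name: attn_metadata for name in attn_layer_names}
assert all(md is attn_metadata for md in per_layer_attn_metadata.values())
```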
@@ -159,7 +165,7 @@ def propose(
        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states

-        with set_forward_context(attn_metadata,
+        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            ret_hidden_states = self.model(
@@ -245,7 +251,7 @@ def propose(
            self.hidden_states[:batch_size] = hidden_states

            # Run the model.
-            with set_forward_context(attn_metadata,
+            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                last_hidden_states, hidden_states = self.model(
@@ -318,8 +324,8 @@ def load_model(self, target_model: nn.Module) -> None:
        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
            target_attn_layer_names)
-        assert len(draft_attn_layer_names) == 1
-        self.attn_layer_name = next(iter(draft_attn_layer_names))
+
+        self.attn_layer_names = list(draft_attn_layer_names)

        # share embed_tokens with the target model if needed
        if get_pp_group().world_size == 1:
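The `load_model` change drops the single-layer assumption: instead of asserting that exactly one attention layer remains after subtracting the target model's layers, the whole set difference is kept as a list of draft layer names. A small sketch of that set arithmetic with hypothetical layer names (the real names depend on the loaded models):

```python
# Toy sketch (not vLLM code): hypothetical layer names illustrating the
# set difference above.
target_attn_layer_names = {
    "model.layers.0.self_attn.attn",
    "model.layers.1.self_attn.attn",
}
# After loading the draft model, the registry contains both target and
# draft attention layers.
all_attn_layer_names = target_attn_layer_names | {
    "draft_model.layers.0.self_attn.attn",  # hypothetical EAGLE layer
    "draft_model.layers.1.self_attn.attn",  # hypothetical EAGLE layer
}

# Previously this difference was asserted to contain exactly one layer;
# now every draft layer is kept.
draft_attn_layer_names = all_attn_layer_names - target_attn_layer_names
attn_layer_names = list(draft_attn_layer_names)
assert set(attn_layer_names) == {
    "draft_model.layers.0.self_attn.attn",
    "draft_model.layers.1.self_attn.attn",
}
```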
@@ -355,6 +361,25 @@ def dummy_run(
                self.hidden_states[:num_tokens],
            )

+    def validate_same_kv_cache_group(self,
+                                     kv_cache_config: KVCacheConfig) -> None:
+        """
+        Validate that all eagle layers belong to the same KVCacheGroup.
+        Need this assumption to ensure all eagle layers can use the
+        same AttentionMetadata.
+        May extend to multiple AttentionMetadata in the future.
+        """
+        kv_cache_groups: dict[str, int] = {}
+        for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
+            for layer_name in kv_cache_group.layer_names:
+                kv_cache_groups[layer_name] = id
+        assert len(
+            set([
+                kv_cache_groups[layer_name]
+                for layer_name in self.attn_layer_names
+            ])
+        ) == 1, "All eagle layers should belong to the same kv cache group"
+

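For context, a self-contained sketch of the check the new `validate_same_kv_cache_group` method performs, written against stand-in classes rather than the real `KVCacheConfig` (only the attributes the method actually reads are modeled, and all layer names are hypothetical):

```python
# Toy sketch (not vLLM code): the same grouping check, exercised with
# stand-in classes that only model kv_cache_groups / layer_names.
from dataclasses import dataclass

@dataclass
class FakeKVCacheGroup:
    layer_names: list[str]

@dataclass
class FakeKVCacheConfig:
    kv_cache_groups: list[FakeKVCacheGroup]

def check_same_group(attn_layer_names: list[str],
                     kv_cache_config: FakeKVCacheConfig) -> None:
    # Map each layer name to the index of its KV cache group, then require
    # that all draft layers land in a single group.
    layer_to_group = {
        layer_name: group_id
        for group_id, group in enumerate(kv_cache_config.kv_cache_groups)
        for layer_name in group.layer_names
    }
    assert len({layer_to_group[name] for name in attn_layer_names}) == 1, \
        "All eagle layers should belong to the same kv cache group"

# Hypothetical layout: both draft layers sit in group 1, so this passes;
# splitting them across groups would trip the assertion.
cfg = FakeKVCacheConfig(kv_cache_groups=[
    FakeKVCacheGroup(layer_names=["target.layers.0.attn"]),
    FakeKVCacheGroup(layer_names=["draft.layers.0.attn",
                                  "draft.layers.1.attn"]),
])
check_same_group(["draft.layers.0.attn", "draft.layers.1.attn"], cfg)
```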
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage