File tree 3 files changed +14
-1
lines changed
3 files changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -179,7 +179,7 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
179
179
180
180
181
181
# TRACKING: https://github.com/vllm-project/vllm/issues/18166
182
- @pytest .mark .skip (reason = "RE-ENABLE: Failing on main." )
182
+ # @pytest.mark.skip(reason="RE-ENABLE: Failing on main.")
183
183
@pytest .mark .parametrize (
184
184
"common_llm_kwargs" ,
185
185
[{
Original file line number Diff line number Diff line change @@ -145,6 +145,17 @@ def forward(
145
145
if inputs_embeds is None :
146
146
inputs_embeds = self .get_input_embeddings (input_ids )
147
147
148
+ # Handle both empty previous_hidden_states
149
+ # and mismatched batch size
150
+ batch_size = inputs_embeds .size (0 )
151
+ if previous_hidden_states .size (0 ) == 0 or \
152
+ previous_hidden_states .size (0 ) != batch_size :
153
+ hidden_dim = self .config .model .hidden_size
154
+ device = inputs_embeds .device
155
+ # Create zero tensor with matching batch size
156
+ previous_hidden_states = \
157
+ torch .zeros (batch_size , hidden_dim , device = device )
158
+
148
159
if self .add_para_norm :
149
160
inputs_embeds = torch .cat ([
150
161
self .enorm (inputs_embeds ),
Original file line number Diff line number Diff line change @@ -1330,6 +1330,8 @@ def prune(self,
1330
1330
# may be "paused" then "resumed" later. This should only prune sequences
1331
1331
# which are confirmed to be aborted.
1332
1332
seq_ids = get_all_seq_ids (seq_group_metadata_list )
1333
+ # Only keep sequence IDs that exist in self._seq_ids
1334
+ seq_ids = [seq_id for seq_id in seq_ids if seq_id in self ._seq_ids ]
1333
1335
if seq_ids != self ._seq_ids :
1334
1336
# Batch contents changed - prune removed sequences.
1335
1337
index = [self ._seq_ids .index (seq_id ) for seq_id in seq_ids ]
You can’t perform that action at this time.
0 commit comments