Skip to content

Commit 2757f4a

Browse files
committed
fix
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
1 parent 09f106a commit 2757f4a

File tree

3 files changed: +14 additions, -1 deletion

tests/spec_decode/e2e/test_eagle_correctness.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
179179

180180

181181
# TRACKING: https://github.com/vllm-project/vllm/issues/18166
182-
@pytest.mark.skip(reason="RE-ENABLE: Failing on main.")
182+
# @pytest.mark.skip(reason="RE-ENABLE: Failing on main.")
183183
@pytest.mark.parametrize(
184184
"common_llm_kwargs",
185185
[{

vllm/model_executor/models/eagle.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,17 @@ def forward(
145145
if inputs_embeds is None:
146146
inputs_embeds = self.get_input_embeddings(input_ids)
147147

148+
# Handle both empty previous_hidden_states
149+
# and mismatched batch size
150+
batch_size = inputs_embeds.size(0)
151+
if previous_hidden_states.size(0) == 0 or \
152+
previous_hidden_states.size(0) != batch_size:
153+
hidden_dim = self.config.model.hidden_size
154+
device = inputs_embeds.device
155+
# Create zero tensor with matching batch size
156+
previous_hidden_states = \
157+
torch.zeros(batch_size, hidden_dim, device=device)
158+
148159
if self.add_para_norm:
149160
inputs_embeds = torch.cat([
150161
self.enorm(inputs_embeds),

vllm/sequence.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,6 +1330,8 @@ def prune(self,
13301330
# may be "paused" then "resumed" later. This should only prune sequences
13311331
# which are confirmed to be aborted.
13321332
seq_ids = get_all_seq_ids(seq_group_metadata_list)
1333+
# Only keep sequence IDs that exist in self._seq_ids
1334+
seq_ids = [seq_id for seq_id in seq_ids if seq_id in self._seq_ids]
13331335
if seq_ids != self._seq_ids:
13341336
# Batch contents changed - prune removed sequences.
13351337
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]

0 commit comments

Comments (0)