Skip to content

Commit 212d4cd

Browse files
committed
consistent shape between logits and generated_ids
1 parent 4aa4ebd commit 212d4cd

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

exllamav2/generator/sampler.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,9 @@ def sample(
364364
# Apply logits processor
365365

366366
if settings.logits_processor:
367-
generated_ids = sequence_ids[:, input_ids.shape[1]:]
367+
generated_ids = sequence_ids[:, input_ids.shape[1]:].view(
368+
logits.shape[:-1] + sequence_ids.shape[-1:] # ensure consistent batch dimensions
369+
)
368370
logits = settings.logits_processor(generated_ids, logits)
369371

370372
# Prepare filter

0 commit comments

Comments
 (0)