Skip to content

Commit ce08f16

Browse files
committed
consistent shape between logits and generated_ids
1 parent 4aa4ebd commit ce08f16

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

exllamav2/generator/sampler.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -364,8 +364,12 @@ def sample(
364364
# Apply logits processor
365365

366366
if settings.logits_processor:
367-
generated_ids = sequence_ids[:, input_ids.shape[1]:]
368-
logits = settings.logits_processor(generated_ids, logits)
367+
generated_ids = sequence_ids[:, input_ids.shape[-1]:]
368+
# normalize to 2d
369+
logits_2d = logits.view(-1, logits.shape[-1])
370+
generated_ids_2d = generated_ids.view(logits_2d.shape[0], generated_ids.shape[-1])
371+
# process logits and convert back to original logits shape
372+
logits = settings.logits_processor(generated_ids_2d, logits_2d).view(logits.shape)
369373

370374
# Prepare filter
371375

0 commit comments

Comments
 (0)