We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4aa4ebd commit ce08f16Copy full SHA for ce08f16
exllamav2/generator/sampler.py
@@ -364,8 +364,12 @@ def sample(
364
# Apply logits processor
365
366
if settings.logits_processor:
367
- generated_ids = sequence_ids[:, input_ids.shape[1]:]
368
- logits = settings.logits_processor(generated_ids, logits)
+ generated_ids = sequence_ids[:, input_ids.shape[-1]:]
+ # normalize to 2d
369
+ logits_2d = logits.view(-1, logits.shape[-1])
370
+ generated_ids_2d = generated_ids.view(logits_2d.shape[0], generated_ids.shape[-1])
371
+ # process logits and convert back to original logits shape
372
+ logits = settings.logits_processor(generated_ids_2d, logits_2d).view(logits.shape)
373
374
# Prepare filter
375
0 commit comments