We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 428f961, commit 09fc2d6 (Copy full SHA for 09fc2d6)
exllamav2/generator/sampler.py
@@ -188,11 +188,6 @@ def sample(
188
189
logits = logits.squeeze(1)
190
191
- # Prepare filter
192
-
193
- logit_filter = torch.empty((batch_size, vocab_size), dtype = torch.bool)
194
- ext_c.fast_fill_cpu_ones_bool(logit_filter)
195
196
# Sync
197
198
if sync:
@@ -206,6 +201,11 @@ def sample(
206
201
logits = logits.unsqueeze(0)
207
202
batch_size = 1
208
203
204
+ # Prepare filter
205
+
+ logit_filter = torch.empty((batch_size, vocab_size), dtype = torch.bool)
+ ext_c.fast_fill_cpu_ones_bool(logit_filter)
209
# Repetition penalty
210
211
if settings.token_repetition_penalty != 1.0 or \
0 commit comments