Skip to content

Commit 09fc2d6

Browse files
committed
Unbreak CFG
1 parent 428f961 commit 09fc2d6

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

exllamav2/generator/sampler.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,6 @@ def sample(
188188

189189
logits = logits.squeeze(1)
190190

191-
# Prepare filter
192-
193-
logit_filter = torch.empty((batch_size, vocab_size), dtype = torch.bool)
194-
ext_c.fast_fill_cpu_ones_bool(logit_filter)
195-
196191
# Sync
197192

198193
if sync:
@@ -206,6 +201,11 @@ def sample(
206201
logits = logits.unsqueeze(0)
207202
batch_size = 1
208203

204+
# Prepare filter
205+
206+
logit_filter = torch.empty((batch_size, vocab_size), dtype = torch.bool)
207+
ext_c.fast_fill_cpu_ones_bool(logit_filter)
208+
209209
# Repetition penalty
210210

211211
if settings.token_repetition_penalty != 1.0 or \

0 commit comments

Comments
 (0)