
Commit eaf5426

Skip first forward pass when rewinding after banned string
1 parent: 3fe6ca8
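
In effect, the generator now remembers the device-side logits that produced the token which triggered a ban checkpoint: `set_checkpoint()` stores `dev_logits` alongside the other rewind state, and `rewind_checkpoint()` hands them back through the new `reuse_logits` field. The next `_gen_single_token()` call can then simply bump `cache.current_seq_len` and resample (with the offending token blocked) instead of repeating a forward pass whose result it already has. Below is a minimal, self-contained sketch of that pattern; all names (`TinyLM`, `Generator`, `gen_token`, ...) are illustrative stand-ins, not exllamav2's API:

import random

class TinyLM:
    """Stand-in 'model': returns fake logits and counts forward passes."""
    def __init__(self):
        self.forward_calls = 0

    def forward(self, token):
        self.forward_calls += 1
        rng = random.Random(token)
        return [rng.random() for _ in range(8)]  # fake logits over a vocab of 8

def sample(logits, blocked = ()):
    # Greedy pick, skipping blocked token ids (stands in for the string ban)
    return max((i for i in range(len(logits)) if i not in blocked),
               key = lambda i: logits[i])

class Generator:
    def __init__(self, model):
        self.model = model
        self.checkpoint = None
        self.reuse_logits = None  # logits stashed by a rewound checkpoint

    def gen_token(self, prev_token, blocked = ()):
        if self.reuse_logits is not None:
            # Rewound past a banned string: the forward pass for this position
            # already ran once, so reuse its logits instead of calling the model.
            logits, self.reuse_logits = self.reuse_logits, None
        else:
            logits = self.model.forward(prev_token)
        return sample(logits, blocked), logits

    def set_checkpoint(self, logits, offending_token):
        self.checkpoint = {"next_logits": logits, "offending_token": offending_token}

    def rewind_checkpoint(self):
        cp, self.checkpoint = self.checkpoint, None
        self.reuse_logits = cp["next_logits"]  # skip the next forward pass
        return cp["offending_token"]

model = TinyLM()
gen = Generator(model)

token, logits = gen.gen_token(prev_token = 3)      # forward pass #1
gen.set_checkpoint(logits, token)                  # token began a banned string
bad = gen.rewind_checkpoint()                      # roll back, stash the logits
retry, _ = gen.gen_token(prev_token = 3, blocked = {bad})  # no forward pass here
assert retry != bad and model.forward_calls == 1

Note that the shortcut only applies to the plain single-model path: the ngram and draft-model branches in the diff set `dev_logits = None`, so there are no saved logits to reuse after a rewind on those paths.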

2 files changed: +26 -11 lines changed


examples/banned_strings.py (+1)

@@ -56,6 +56,7 @@ def format_prompt(sp, p):
     "Keep in mind that",
     "encourage or facilitate harmful",
     "I must emphasize",
+    "However, I must",
     "I would like to emphasize",
     "Instead of providing",
     "Instead of pursuing",

exllamav2/generator/streaming.py (+25 -11)

@@ -118,6 +118,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
     blocked_tokens: list[int]
     blocked_position: int
     current_blocked_tokens: list[int]
+    reuse_logits: torch.Tensor | None


     def __init__(self, model, cache, tokenizer, draft_model = None, draft_cache = None, num_speculative_tokens = 5):
@@ -372,6 +373,7 @@ def begin_stream_ex(self,
         self.blocked_tokens = []
         self.blocked_position = -1
         self.current_blocked_tokens = []
+        self.reuse_logits = None


         # Convert list of strings to the UTF32 format needed to pass by reference to the partial matching function
@@ -512,7 +514,7 @@ def _stream(self, ban_tokens: list[str] | None = None) -> (str, bool, torch.Tens

         # Regenerate the last token again, with prefix

-        healed_token, _, _, _, eos, logits = self._gen_single_token(self.settings, prefix_token = last_token)
+        healed_token, _, _, _, eos, logits, dev_logits = self._gen_single_token(self.settings, prefix_token = last_token)
         new_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:],
                                          decode_special_tokens = self.decode_special_tokens)[0]
         self.held_text += new_tail[len(old_tail):]
@@ -534,7 +536,7 @@ def _stream(self, ban_tokens: list[str] | None = None) -> (str, bool, torch.Tens

         # Generate a single token and append to the sequence

-        next_token, next_ptokens, next_pprobs, next_prob, eos, next_logits = self._gen_single_token(self.settings)
+        next_token, next_ptokens, next_pprobs, next_prob, eos, next_logits, dev_logits = self._gen_single_token(self.settings)

         # End immediately if it was a stop token

@@ -572,7 +574,8 @@ def set_checkpoint():
                 "held_ptokens": self.held_ptokens[:, :-1, :],
                 "held_pprobs": self.held_pprobs[:, :-1, :],
                 "held_logits": self.held_logits[:-1, :],
-                "offending_token": next_token
+                "offending_token": next_token,
+                "next_logits": dev_logits
             }
             self.blocked_position = self.cache.current_seq_len - 1

@@ -587,8 +590,10 @@ def rewind_checkpoint():
             self.held_ptokens = cp["held_ptokens"]
             self.held_pprobs = cp["held_pprobs"]
             self.held_logits = cp["held_logits"]
+            self.future_logits = None
             self.future_tokens = None
             self.ban_checkpoint = None
+            self.reuse_logits = cp["next_logits"]
             return cp["offending_token"], off_text

         if self.banned_strings_utf32_offsets is not None:
@@ -836,16 +841,24 @@ def _gen_single_token(self, gen_settings, prefix_token = None):
         if self.speculative_ngram:

             token, ptokens, pprobs, prob, eos, logits = self._gen_single_token_ngram(gen_settings, prefix_token)
+            dev_logits = None

         elif self.draft_model is None:

-            logits = self.model.forward(
-                self.sequence_ids[:, -1:],
-                self.cache,
-                loras = self.active_loras,
-                input_mask = self.input_mask,
-                position_offsets = self.position_offsets
-            ).float().cpu()
+            if self.reuse_logits is not None:
+                dev_logits = self.reuse_logits
+                self.reuse_logits = None
+                self.cache.current_seq_len += 1
+                logits = dev_logits.float().cpu()
+            else:
+                dev_logits = self.model.forward(
+                    self.sequence_ids[:, -1:],
+                    self.cache,
+                    loras = self.active_loras,
+                    input_mask = self.input_mask,
+                    position_offsets = self.position_offsets
+                )
+                logits = dev_logits.float().cpu()

             token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(
                 logits,
@@ -862,6 +875,7 @@ def _gen_single_token(self, gen_settings, prefix_token = None):

             token, ptokens, pprobs, prob, eos, logits = \
                 self._gen_single_token_speculative(gen_settings, prefix_token)
+            dev_logits = None

         # Post sampling hook

@@ -889,7 +903,7 @@ def _gen_single_token(self, gen_settings, prefix_token = None):
         else:
             self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)

-        return token, ptokens, pprobs, prob, eos, logits.flatten(1)
+        return token, ptokens, pprobs, prob, eos, logits.flatten(1), dev_logits


     # Speculative decoding with draft model
