@@ -118,6 +118,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
     blocked_tokens: list[int]
     blocked_position: int
     current_blocked_tokens: list[int]
+    reuse_logits: torch.Tensor | None


     def __init__(self, model, cache, tokenizer, draft_model = None, draft_cache = None, num_speculative_tokens = 5):
@@ -372,6 +373,7 @@ def begin_stream_ex(self,
         self.blocked_tokens = []
         self.blocked_position = -1
         self.current_blocked_tokens = []
+        self.reuse_logits = None


         # Convert list of strings to UTF32 format needed, to pass by reference to partial matching function
@@ -512,7 +514,7 @@ def _stream(self, ban_tokens: list[str] | None = None) -> (str, bool, torch.Tens

             # Regenerate the last token again, with prefix

-            healed_token, _, _, _, eos, logits = self._gen_single_token(self.settings, prefix_token = last_token)
+            healed_token, _, _, _, eos, logits, dev_logits = self._gen_single_token(self.settings, prefix_token = last_token)
             new_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:],
                                              decode_special_tokens = self.decode_special_tokens)[0]
             self.held_text += new_tail[len(old_tail):]
@@ -534,7 +536,7 @@ def _stream(self, ban_tokens: list[str] | None = None) -> (str, bool, torch.Tens

         # Generate a single token and append to the sequence

-        next_token, next_ptokens, next_pprobs, next_prob, eos, next_logits = self._gen_single_token(self.settings)
+        next_token, next_ptokens, next_pprobs, next_prob, eos, next_logits, dev_logits = self._gen_single_token(self.settings)

         # End immediately if it was a stop token

@@ -572,7 +574,8 @@ def set_checkpoint():
                 "held_ptokens": self.held_ptokens[:, :-1, :],
                 "held_pprobs": self.held_pprobs[:, :-1, :],
                 "held_logits": self.held_logits[:-1, :],
-                "offending_token": next_token
+                "offending_token": next_token,
+                "next_logits": dev_logits
             }
             self.blocked_position = self.cache.current_seq_len - 1

@@ -587,8 +590,10 @@ def rewind_checkpoint():
             self.held_ptokens = cp["held_ptokens"]
             self.held_pprobs = cp["held_pprobs"]
             self.held_logits = cp["held_logits"]
+            self.future_logits = None
             self.future_tokens = None
             self.ban_checkpoint = None
+            self.reuse_logits = cp["next_logits"]
             return cp["offending_token"], off_text

         if self.banned_strings_utf32_offsets is not None:
@@ -836,16 +841,24 @@ def _gen_single_token(self, gen_settings, prefix_token = None):
         if self.speculative_ngram:

             token, ptokens, pprobs, prob, eos, logits = self._gen_single_token_ngram(gen_settings, prefix_token)
+            dev_logits = None

         elif self.draft_model is None:

-            logits = self.model.forward(
-                self.sequence_ids[:, -1:],
-                self.cache,
-                loras = self.active_loras,
-                input_mask = self.input_mask,
-                position_offsets = self.position_offsets
-            ).float().cpu()
+            if self.reuse_logits is not None:
+                dev_logits = self.reuse_logits
+                self.reuse_logits = None
+                self.cache.current_seq_len += 1
+                logits = dev_logits.float().cpu()
+            else:
+                dev_logits = self.model.forward(
+                    self.sequence_ids[:, -1:],
+                    self.cache,
+                    loras = self.active_loras,
+                    input_mask = self.input_mask,
+                    position_offsets = self.position_offsets
+                )
+                logits = dev_logits.float().cpu()

             token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(
                 logits,
@@ -862,6 +875,7 @@ def _gen_single_token(self, gen_settings, prefix_token = None):

             token, ptokens, pprobs, prob, eos, logits = \
                 self._gen_single_token_speculative(gen_settings, prefix_token)
+            dev_logits = None

         # Post sampling hook

@@ -889,7 +903,7 @@ def _gen_single_token(self, gen_settings, prefix_token = None):
         else:
             self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)

-        return token, ptokens, pprobs, prob, eos, logits.flatten(1)
+        return token, ptokens, pprobs, prob, eos, logits.flatten(1), dev_logits


     # Speculative decoding with draft model
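
Read together, the hunks stash the device-resident logits produced for a token that later turns out to start a banned string ("next_logits" in the ban checkpoint) and replay them through the new reuse_logits attribute on the next call to _gen_single_token, advancing cache.current_seq_len by one instead of running a second forward pass. The following standalone sketch illustrates that reuse pattern under stated assumptions; TinyGenerator, fake_model, and the argmax sampling are hypothetical stand-ins for illustration only, not part of exllamav2.

# Minimal sketch (assumed names, not the exllamav2 API) of the logit-reuse
# pattern added by this change: remember the device logits computed for a
# rejected continuation, then replay them on the next single-token step
# rather than recomputing them with another forward pass.

import torch

class TinyGenerator:

    def __init__(self, model):
        self.model = model           # hypothetical callable: ids -> logits
        self.seq_len = 0             # stands in for cache.current_seq_len
        self.reuse_logits = None     # stashed device logits, if any

    def checkpoint(self, dev_logits):
        # On a rewind, keep the logits already computed for the offending
        # token so the next step can reuse them.
        self.reuse_logits = dev_logits

    def gen_single_token(self, ids):
        if self.reuse_logits is not None:
            dev_logits = self.reuse_logits
            self.reuse_logits = None
            # Advance manually: no forward pass runs, and the cache entries
            # written by the original pass are assumed still valid.
            self.seq_len += 1
        else:
            dev_logits = self.model(ids[:, -1:])
            self.seq_len += 1
        logits = dev_logits.float().cpu()
        return torch.argmax(logits, dim = -1), dev_logits


if __name__ == "__main__":
    fake_model = lambda ids: torch.randn(1, 1, 32000)    # stand-in forward pass
    gen = TinyGenerator(fake_model)
    ids = torch.zeros((1, 4), dtype = torch.long)
    token, dev_logits = gen.gen_single_token(ids)        # normal forward pass
    gen.checkpoint(dev_logits)                           # pretend the token was banned
    token2, _ = gen.gen_single_token(ids)                # replays stashed logits, no forward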