diff --git a/exllamav2/exllamav2_ext/cpp/sampling.cpp b/exllamav2/exllamav2_ext/cpp/sampling.cpp
index 8f9ae7e6..1bf83fc0 100644
--- a/exllamav2/exllamav2_ext/cpp/sampling.cpp
+++ b/exllamav2/exllamav2_ext/cpp/sampling.cpp
@@ -439,6 +439,66 @@ int sort_descending
     return pre;
 }
 
+AVX2_TARGET_OPTIONAL
+void logit_threshold_temperature
+(
+    const int num_candidates,
+    float logit_temp_threshold,
+    float logit_high_temp,
+    const int maxlogit,
+    const float* logits,
+    const float exponent,
+    float* temp_probs
+)
+{
+    profile_start("logit_threshold_temperature");
+
+    float esum = 0.0f;
+    float static_pmass = 0.0f;
+    float itemp = 1.0f / std::max(logit_high_temp, 0.01f);
+    float maxl = logits[maxlogit];
+    std::vector<int> above_threshold_indices;
+
+    for (int i = 0; i < num_candidates; i++)
+    {
+        float target_logit = logits[i];
+
+        float l = target_logit - maxl;
+        if (exponent == 2.0f)
+            l *= -l;
+        else if (exponent != 1.0f)
+            l = -powf(fabs(l), exponent);
+
+        float e = expf(l * itemp);
+        esum += e;
+
+        if (target_logit >= logit_temp_threshold)
+        {
+            temp_probs[i] = e;
+            above_threshold_indices.push_back(i);
+        }
+        else static_pmass += temp_probs[i];
+    }
+
+    float isum = (esum > 0.0f) ? (1.0f / esum) : 1024.0f;
+    float temp_pmass = 0.0f;
+
+    for (int i : above_threshold_indices)
+    {
+        temp_probs[i] *= isum;
+        temp_pmass += temp_probs[i];
+    }
+
+    float adjfactor = (temp_pmass > 0.0f) ? ((1.0f - static_pmass) / temp_pmass) : 1024.0f;
+
+    for (int i : above_threshold_indices)
+    {
+        temp_probs[i] *= adjfactor;
+    }
+
+    profile_stop();
+}
+
 AVX2_TARGET_OPTIONAL
 int top_k_cpu
 (
diff --git a/exllamav2/exllamav2_ext/cpp/sampling.h b/exllamav2/exllamav2_ext/cpp/sampling.h
index 814f90b2..33669159 100644
--- a/exllamav2/exllamav2_ext/cpp/sampling.h
+++ b/exllamav2/exllamav2_ext/cpp/sampling.h
@@ -55,6 +55,17 @@ int sort_descending
     int max_index
 );
 
+void logit_threshold_temperature
+(
+    const int num_candidates,
+    float logit_temp_threshold,
+    float logit_high_temp,
+    const int maxlogit,
+    const float* logits,
+    const float exponent,
+    float* temp_probs
+);
+
 int top_k_cpu
 (
     const int num_candidates,
diff --git a/exllamav2/exllamav2_ext/ext_sampling.cpp b/exllamav2/exllamav2_ext/ext_sampling.cpp
index 2bf312b1..c4704285 100644
--- a/exllamav2/exllamav2_ext/ext_sampling.cpp
+++ b/exllamav2/exllamav2_ext/ext_sampling.cpp
@@ -109,6 +109,8 @@ std::vector<float> sample_basic
     float max_temp = 0.0f,
     float temp_exponent = 1.0f,
     float smoothing_factor = 0.0f,
+    float logit_temp_threshold = 0.0f,
+    float logit_high_temp = 0.0f,
     float skew = 0.0f
 )
 {
@@ -138,8 +140,11 @@ std::vector<float> sample_basic
 
     if (temperature < 0.01)
     {
-        temperature = 1.0f;
-        top_k = 1;
+        if (logit_temp_threshold == 0.0f)
+        {
+            temperature = 1.0f;
+            top_k = 1;
+        }
     }
 
     for (int i = 0; i < bsz; i++)
@@ -164,7 +169,13 @@ std::vector<float> sample_basic
         for (int j = 0; j < vocab_size; j++) temp_indices[j] = j;
         int num_candidates = vocab_size;
 
-        if (top_k > 0 && top_k < vocab_size)
+        if (logit_temp_threshold > 0.0f)
+        {
+            logit_threshold_temperature(num_candidates, logit_temp_threshold, logit_high_temp, maxlogit, logits_ptr + i * vocab_size, exponent, temp_probs);
+            normalize_cpu(num_candidates, temp_probs);
+        }
+
+        if (num_candidates > top_k && top_k > 0 && top_k < vocab_size)
         {
            num_candidates = top_k_cpu(num_candidates, temp_probs, temp_indices, top_k, maxlogit);
            normalize_cpu(num_candidates, temp_probs);
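Reviewer note: below is a minimal NumPy sketch of what `logit_threshold_temperature` computes, assuming `temp_probs` already holds the base-temperature softmax when the kernel runs (which is how `sample_basic` appears to use it, given that the kernel reads `temp_probs[i]` for below-threshold tokens). The function and variable names here are illustrative, not part of the patch: tokens below the threshold keep their existing probability mass, tokens at or above it are re-softmaxed at `logit_high_temp` and then rescaled so the full distribution still sums to 1.

```python
import numpy as np

def threshold_temperature_sketch(logits, base_probs, threshold, high_temp, exponent = 1.0):
    maxl = logits.max()
    l = logits - maxl                           # shift so the max logit maps to 0
    if exponent == 2.0:
        l = -(l * l)                            # quadratic falloff
    elif exponent != 1.0:
        l = -np.abs(l) ** exponent              # general exponent falloff
    e = np.exp(l / max(high_temp, 0.01))        # high-temperature weights (all tokens)
    above = logits >= threshold
    probs = base_probs.copy()
    static_pmass = probs[~above].sum()          # mass reserved for cold tokens
    probs[above] = e[above] / e.sum()           # re-softmax the hot tokens
    probs[above] *= (1.0 - static_pmass) / probs[above].sum()
    return probs                                # sums to 1 by construction

# Illustrative numbers: only the two logits at or above the threshold are re-tempered
logits = np.array([10.0, 8.5, 3.0, 1.0])
base = np.exp(logits - logits.max())
base /= base.sum()
print(threshold_temperature_sketch(logits, base, threshold = 8.0, high_temp = 2.0))
```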
diff --git a/exllamav2/exllamav2_ext/ext_sampling.h b/exllamav2/exllamav2_ext/ext_sampling.h
index 45d01a52..e56ff73e 100644
--- a/exllamav2/exllamav2_ext/ext_sampling.h
+++ b/exllamav2/exllamav2_ext/ext_sampling.h
@@ -40,6 +40,8 @@ std::vector<float> sample_basic
     float max_temp,
     float temp_exponent,
     float smoothing_factor,
+    float logit_temp_threshold,
+    float logit_high_temp,
     float skew
 );
diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py
index bce25ae8..e8e082f2 100644
--- a/exllamav2/generator/dynamic.py
+++ b/exllamav2/generator/dynamic.py
@@ -525,7 +525,7 @@ def set_loras(self, loras: list[ExLlamaV2Lora] | None):
             self.current_loras = loras
         else:
             self.current_loras = [loras]
-
+
 
     def generate(
         self,
@@ -1170,10 +1170,10 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None):
             for i in range(batch_logits.shape[1]):
                 job_logits = batch_logits[a:b, i:i+1, :]
                 if i == 0 and mt_sample:
-                    next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \
+                    next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag = \
                         futures.popleft().result()
                 else:
-                    next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \
+                    next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag = \
                         job.receive_logits(job_logits)
 
                 eos, sampled_token = job.receive_sample(
@@ -1183,6 +1183,7 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None):
                     next_k_probs,
                     next_prob,
                     filter_eos,
+                    confidence_flag,
                     results
                 )
@@ -1661,6 +1662,12 @@ def __init__(
 
         self.checkpoint = None
 
+        # Confidence breaker
+
+        self.confidence_breaker = gen_settings.confidence_breaker
+        self.confidence_breaker_debug = gen_settings.confidence_breaker_debug
+        self.confidence_flag_sequence = [False] * self.confidence_breaker
+
         # Measurement
 
         self.time_enqueue = None
@@ -1762,7 +1769,7 @@ def receive_logits(
         else:
             blocked_tokens = self.stop_tokens_list
 
-        next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \
+        next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag = \
             ExLlamaV2Sampler.sample(
                 logits,
                 self.gen_settings,
@@ -1777,7 +1784,7 @@
                 # sync = True
             )
 
-        return next_token, next_k_tokens, next_k_probs, next_prob, filter_eos
+        return next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag
 
 
     def receive_sample(
@@ -1788,6 +1795,7 @@ def receive_sample(
         next_k_probs: torch.Tensor | None,
         next_prob: torch.Tensor | None,
         filter_eos: bool | None,
+        confidence_flag: bool | None,
         results: list
     ):
         page_size = self.generator.page_size
@@ -1801,6 +1809,14 @@ def receive_sample(
             if f.use_background_worker():
                 self.generator.filter_queue.append(f)
 
+        # Update confidence_flag_sequence
+
+        if confidence_flag is not None:
+            self.confidence_flag_sequence.append(confidence_flag)
+            # Limit the size of the sequence to prevent it from growing indefinitely
+            if len(self.confidence_flag_sequence) > self.confidence_breaker + 1:
+                self.confidence_flag_sequence.pop(0)
+
         # Accept token
 
         self.new_tokens += 1
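The flag history above is a plain list trimmed to at most `confidence_breaker + 1` entries, i.e. a bounded sliding window. A sketch of the same semantics with `collections.deque`, just to make the window behavior explicit (names here are illustrative, not part of the patch):

```python
from collections import deque

confidence_breaker = 4                            # window length, from gen_settings
flags = deque([False] * confidence_breaker,       # seeded the way the job __init__ does
              maxlen = confidence_breaker + 1)

def record_flag(confidence_flag):
    # None flags are skipped, mirroring receive_sample; with maxlen set,
    # appending past capacity silently drops the oldest entry on the left
    if confidence_flag is not None:
        flags.append(confidence_flag)
```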
@@ -1986,7 +2002,7 @@ def emit(
         # End on stop tokens
 
         if next_token.item() in self.stop_tokens:
-            return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())
+            return emit(results, emit_eos = True, emit_held = True, eos_reason = "stop_token", stop_token = next_token.item())
 
         # Stop if we reach max_new_tokens
         # TODO: Auto-extend option
@@ -2011,7 +2027,8 @@ def emit(
             else:
                 return emit(results)
 
-        # Hold text as long as it contains part of a banned string
+        # Hold text as long as it contains part of a banned string,
+        # or until we know a confidence breaker will not be triggered
 
         def unset_checkpoint():
             self.checkpoint = None
@@ -2026,6 +2043,7 @@ def set_checkpoint():
                     "held_k_tokens": self.held_k_tokens.clone(1),
                     "held_k_probs": self.held_k_probs.clone(1),
                     "held_logits": self.held_logits.clone(1),
+                    "flag_sequence": self.confidence_flag_sequence[:-1].copy(),
                     "explored_tokens": [next_token.item()],
                 }
             else:
@@ -2054,28 +2072,59 @@ def rewind_checkpoint():
             off_tokens = self.held_tokens.slice(len(self.checkpoint["held_tokens"]), None)
             off_text = self.held_text[len(self.checkpoint["held_text"]):]
             self.held_text = self.checkpoint["held_text"]
-            self.held_token = self.checkpoint["held_tokens"]
+            self.held_tokens = self.checkpoint["held_tokens"]
             self.held_probs = self.checkpoint["held_probs"]
             self.held_k_tokens = self.checkpoint["held_k_tokens"]
             self.held_k_probs = self.checkpoint["held_k_probs"]
             self.held_logits = self.checkpoint["held_logits"]
+            self.confidence_flag_sequence = self.checkpoint["flag_sequence"]
             self.checkpoint["offset"] = 0
             return off_tokens, off_text
 
-        if self.banned_strings_utf32_offsets is not None and self.new_tokens > 0:
-            match = ext_c.partial_strings_match(
-                np.frombuffer(self.held_text.lower().encode("utf-32-le"), dtype = np.uint8),
-                self.banned_strings_utf32_offsets,
-                self.banned_strings_utf32_buffer
-            )
-            if match >= 0:
+        # Handle banned strings and confidence flags using checkpointing
+
+        if self.new_tokens > 0:
+
+            # Check for banned strings
+
+            banned_string_match = -1
+            if self.banned_strings_utf32_offsets is not None:
+                banned_string_match = ext_c.partial_strings_match(
+                    np.frombuffer(self.held_text.lower().encode("utf-32-le"), dtype = np.uint8),
+                    self.banned_strings_utf32_offsets,
+                    self.banned_strings_utf32_buffer
+                )
+
+            confidence_breaker_match = -1
+            if self.confidence_breaker > 0:
+                # Check for confidence_flag sequence
+                if confidence_flag is not None:
+                    last_n_flags = self.confidence_flag_sequence[-self.confidence_breaker:]
+                    if not confidence_flag:
+                        confidence_breaker_match = -1  # False flag
+                    elif all(last_n_flags):
+                        confidence_breaker_match = 1  # Match
+                    else:
+                        confidence_breaker_match = -2  # Partial match, wait and see
+                elif self.confidence_flag_sequence[-1]:
+                    confidence_breaker_match = -2  # Pause current sequence without resetting partial match
+                else:
+                    confidence_breaker_match = -1  # Treat None as False flag, following previous False flag
+
+            if confidence_breaker_match >= 0:  # Match confidence breaker
+                set_checkpoint()
+                if self.confidence_breaker_debug:
+                    print(f'[Confidence breaker activated on text: "{self.held_text}"]', flush=True)
+                offending_tokens, offending_text = rewind_checkpoint()
+                return emit(results, suppressed_text = offending_text, suppressed_tokens = offending_tokens)
+            elif banned_string_match >= 0:
                 set_checkpoint()
                 offending_tokens, offending_text = rewind_checkpoint()
-                return emit(results, emit_held = True, suppressed_text = offending_text, suppressed_tokens = offending_tokens)
-            elif match == -2:
+                return emit(results, suppressed_text = offending_text, suppressed_tokens = offending_tokens)
+            elif banned_string_match == -2 or confidence_breaker_match == -2:  # Partial match
                 set_checkpoint()
                 return emit(results)
-            else:
+            else:  # Reset and permit text passthrough
+                if len(self.full_completion) > 0:
+                    set_checkpoint()
                 unset_checkpoint()
 
         # End on stop strings
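The branch ladder above reduces to a three-state decision over the flag history. This standalone restatement (a hypothetical helper, not part of the patch) may help when checking the branch order: 1 fires the breaker and rewinds, -2 keeps holding text, -1 releases it.

```python
def breaker_match(flag_sequence, confidence_flag, n):
    # flag_sequence is the history as maintained by receive_sample: the
    # current flag has already been appended when it is not None
    if confidence_flag is not None:
        if not confidence_flag:
            return -1                  # False flag: the run is broken
        elif all(flag_sequence[-n:]):
            return 1                   # n consecutive True flags: breaker fires
        else:
            return -2                  # partial match, wait and see
    elif flag_sequence[-1]:
        return -2                      # None pauses the run without resetting it
    else:
        return -1                      # None after a False flag counts as False

assert breaker_match([True] * 4, True, 4) == 1
assert breaker_match([True, True, False, True], True, 4) == -2
assert breaker_match([True] * 4, None, 4) == -2
```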
diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py
index dcad0ee1..de852568 100644
--- a/exllamav2/generator/sampler.py
+++ b/exllamav2/generator/sampler.py
@@ -73,6 +73,14 @@ class Settings:
     temperature_last: bool = False
 
+    logit_temp_threshold: float = 0.0
+    logit_high_temp: float = 0.0
+
+    confidence_breaker: int = 0
+    confidence_breaker_debug: bool = False
+    cb_mid_threshold: float = 0.0
+    cb_high_threshold: float = 0.0
+
     mirostat: bool = False
     mirostat_tau: float = 1.5
     mirostat_eta: float = 0.1
@@ -420,6 +428,7 @@ def prep_logit_filter(lf):
         # Temporarily ban individual tokens
 
         if blocked_tokens:
+            saved_logits = logits[:, :, blocked_tokens].clone()
             logits[:, :, blocked_tokens] = -1e30
 
         # Token bias
@@ -552,9 +561,34 @@ def prep_logit_filter(lf):
                 settings.max_temp,
                 settings.temp_exponent,
                 settings.smoothing_factor,
+                settings.logit_temp_threshold,
+                settings.logit_high_temp,
                 settings.skew
             )
 
+        if settings.confidence_breaker > 0:
+            if blocked_tokens and 'saved_logits' in locals():
+                # Restore the saved logits values for the blocked tokens
+                logits[:, :, blocked_tokens] = saved_logits
+
+            squeezed_logits = logits.squeeze(0).squeeze(0)
+            probs = F.softmax(squeezed_logits, dim = -1)
+            token_prob = probs[output_tokens]
+            token_logit = squeezed_logits[output_tokens]
+            # Thresholds <= 1.0 are read as probabilities, larger values as raw logits
+            if settings.cb_mid_threshold <= 1.0:
+                confidence_flag = (token_prob >= settings.cb_mid_threshold).item()
+            else:
+                confidence_flag = (token_logit >= settings.cb_mid_threshold).item()
+            if settings.cb_high_threshold <= 1.0:
+                if (token_prob > settings.cb_high_threshold).item():
+                    confidence_flag = None
+            else:
+                if (token_logit > settings.cb_high_threshold).item():
+                    confidence_flag = None
+        else:
+            confidence_flag = None
+
         if settings.mirostat: settings.mirostat_mu = m
 
         # Stop condition from filters
@@ -563,4 +596,4 @@ def prep_logit_filter(lf):
         if len(filters) > 0 and end_tokens is not None and output_tokens[0].item() in end_tokens:
             end_filter = True
 
-    return output_tokens, output_ktokens, output_kprobs, output_probs, end_filter
+    return output_tokens, output_ktokens, output_kprobs, output_probs, end_filter, confidence_flag
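Finally, a hypothetical end-to-end configuration of the new settings (the field names come from this patch; the values are made up). Per the sampler change above, thresholds at or below 1.0 are read as probabilities and larger values as raw logits; note that `cb_high_threshold` must be set above `cb_mid_threshold` for the breaker to see any True flags at all, since tokens over the high threshold are flagged None.

```python
from exllamav2.generator import ExLlamaV2Sampler

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.8

# Re-temper only tokens whose raw logit reaches 14.0; colder tokens keep
# their base-temperature probabilities
settings.logit_temp_threshold = 14.0
settings.logit_high_temp = 1.6

# Rewind after 4 consecutive tokens whose probability lands between the
# mid and high thresholds; tokens above 0.75 pause the count without
# resetting it, tokens below 0.25 reset it
settings.confidence_breaker = 4
settings.cb_mid_threshold = 0.25
settings.cb_high_threshold = 0.75
settings.confidence_breaker_debug = True
```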