From 34c33c3952742707e1fdbe86735f6e5c5ae6ca06 Mon Sep 17 00:00:00 2001 From: anchortense Date: Thu, 17 Oct 2024 18:30:51 +1100 Subject: [PATCH 1/4] Implementation of logit threshold sampler and confidence breaker using c++ extensions --- exllamav2/exllamav2_ext/cpp/sampling.cpp | 67 +++++++++++++++++++ exllamav2/exllamav2_ext/cpp/sampling.h | 12 ++++ exllamav2/exllamav2_ext/ext_sampling.cpp | 19 +++++- exllamav2/exllamav2_ext/ext_sampling.h | 2 + exllamav2/generator/dynamic.py | 83 +++++++++++++++++++----- exllamav2/generator/sampler.py | 75 ++++++++++++++++++++- 6 files changed, 239 insertions(+), 19 deletions(-) diff --git a/exllamav2/exllamav2_ext/cpp/sampling.cpp b/exllamav2/exllamav2_ext/cpp/sampling.cpp index 8f9ae7e6..7a677244 100644 --- a/exllamav2/exllamav2_ext/cpp/sampling.cpp +++ b/exllamav2/exllamav2_ext/cpp/sampling.cpp @@ -439,6 +439,73 @@ int sort_descending return pre; } +AVX2_TARGET_OPTIONAL +int logit_threshold_restore +( + float logit_min_threshold, + float logit_temp_threshold, + const int maxlogit, + const int vocab_size, + const float* logits, + const float exponent, + float* temp_probs, + int* temp_indices +) +{ + profile_start("logit_threshold_restore"); + + float esum = 0.0f; + int n = 0; + float maxl = logits[maxlogit]; + float effective_min = std::min(maxl, logit_min_threshold); + + for (int i = 0; i < vocab_size; i++) + { + float target_logit = logits[i]; + if (target_logit < effective_min) continue; + float l = target_logit - maxl; + if (exponent == 2.0f) + l *= -l; + else if (exponent != 1.0f) + l = -powf(fabs(l), exponent); + float e = expf(l); + esum += e; + if (target_logit < logit_temp_threshold) + temp_probs[i] = e; + } + + float isum = 1.0f / esum; + float diffsum = 0.0f; + + for (int i = 0; i < vocab_size; i++) + { + if (logits[i] < effective_min) continue; + if (logits[i] < logit_temp_threshold) + { + temp_probs[i] *= isum; + diffsum += temp_probs[i]; + n++; + } + } + + float adjfactor = 1.0f - diffsum; + + for (int i = 0; i < vocab_size; i++) + { + if (logits[i] >= logit_temp_threshold) + { + temp_probs[i] *= adjfactor; + n++; + } + } + + sort_descending(vocab_size, temp_probs, temp_indices, n); + + profile_stop(); + return n; + +} + AVX2_TARGET_OPTIONAL int top_k_cpu ( diff --git a/exllamav2/exllamav2_ext/cpp/sampling.h b/exllamav2/exllamav2_ext/cpp/sampling.h index 814f90b2..316aca5c 100644 --- a/exllamav2/exllamav2_ext/cpp/sampling.h +++ b/exllamav2/exllamav2_ext/cpp/sampling.h @@ -55,6 +55,18 @@ int sort_descending int max_index ); +int logit_threshold_restore +( + float logit_min_threshold, + float logit_temp_threshold, + const int maxlogit, + const int vocab_size, + const float* logits, + const float exponent, + float* temp_probs, + int* temp_indices +); + int top_k_cpu ( const int num_candidates, diff --git a/exllamav2/exllamav2_ext/ext_sampling.cpp b/exllamav2/exllamav2_ext/ext_sampling.cpp index 2bf312b1..4a1f5c11 100644 --- a/exllamav2/exllamav2_ext/ext_sampling.cpp +++ b/exllamav2/exllamav2_ext/ext_sampling.cpp @@ -85,6 +85,8 @@ std::vector sample_basic ( torch::Tensor logits, // shape [bsz, 1, vocab_size] float temperature, + float logit_temp_threshold, + float logit_min_threshold, int top_k, float top_p, float top_a, @@ -164,7 +166,22 @@ std::vector sample_basic for (int j = 0; j < vocab_size; j++) temp_indices[j] = j; int num_candidates = vocab_size; - if (top_k > 0 && top_k < vocab_size) + if ((logit_temp_threshold > logit_min_threshold) && logit_min_threshold > 0.0f) + { + num_candidates = logit_threshold_restore + ( + logit_min_threshold, + logit_temp_threshold, + maxlogit, + vocab_size, + logits_ptr + i * vocab_size, + exponent, + temp_probs, + temp_indices + ); + } + + if (num_candidates > top_k && top_k > 0 && top_k < vocab_size) { num_candidates = top_k_cpu(num_candidates, temp_probs, temp_indices, top_k, maxlogit); normalize_cpu(num_candidates, temp_probs); diff --git a/exllamav2/exllamav2_ext/ext_sampling.h b/exllamav2/exllamav2_ext/ext_sampling.h index 45d01a52..51c4ed4c 100644 --- a/exllamav2/exllamav2_ext/ext_sampling.h +++ b/exllamav2/exllamav2_ext/ext_sampling.h @@ -16,6 +16,8 @@ std::vector sample_basic ( torch::Tensor logits, // shape [bsz, vocab_size] float temperature, + float logit_temp_threshold, + float logit_min_threshold, int top_k, float top_p, float top_a, diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py index bce25ae8..8cd62898 100644 --- a/exllamav2/generator/dynamic.py +++ b/exllamav2/generator/dynamic.py @@ -525,7 +525,7 @@ def set_loras(self, loras: list[ExLlamaV2Lora] | None): self.current_loras = loras else: self.current_loras = [loras] - + def generate( self, @@ -1170,10 +1170,10 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None): for i in range(batch_logits.shape[1]): job_logits = batch_logits[a:b, i:i+1, :] if i == 0 and mt_sample: - next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \ + next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag = \ futures.popleft().result() else: - next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \ + next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag = \ job.receive_logits(job_logits) eos, sampled_token = job.receive_sample( @@ -1183,6 +1183,7 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None): next_k_probs, next_prob, filter_eos, + confidence_flag, results ) @@ -1661,6 +1662,12 @@ def __init__( self.checkpoint = None + # Confidence breaker + + self.confidence_breaker = gen_settings.confidence_breaker + self.confidence_breaker_debug = gen_settings.confidence_breaker_debug + self.confidence_flag_sequence = [False] * self.confidence_breaker + # Measurement self.time_enqueue = None @@ -1762,7 +1769,7 @@ def receive_logits( else: blocked_tokens = self.stop_tokens_list - next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \ + next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag = \ ExLlamaV2Sampler.sample( logits, self.gen_settings, @@ -1777,7 +1784,7 @@ def receive_logits( # sync = True ) - return next_token, next_k_tokens, next_k_probs, next_prob, filter_eos + return next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag def receive_sample( @@ -1788,6 +1795,7 @@ def receive_sample( next_k_probs: torch.Tensor | None, next_prob: torch.Tensor | None, filter_eos: bool | None, + confidence_flag: bool | None, results: list ): page_size = self.generator.page_size @@ -1801,6 +1809,14 @@ def receive_sample( if f.use_background_worker(): self.generator.filter_queue.append(f) + # Update confidence_flag_sequence + + if confidence_flag is not None: + self.confidence_flag_sequence.append(confidence_flag) + # Limit the size of the sequence to prevent it from growing indefinitely + if len(self.confidence_flag_sequence) > self.confidence_breaker + 1: + self.confidence_flag_sequence.pop(0) + # Accept token self.new_tokens += 1 @@ -2011,7 +2027,8 @@ def emit( else: return emit(results) - # Hold text as long as it contains part of a banned string + # Hold text as long as it contains part of a banned string, + # or until we know a confidence breaker will not be triggered def unset_checkpoint(): self.checkpoint = None @@ -2026,6 +2043,7 @@ def set_checkpoint(): "held_k_tokens": self.held_k_tokens.clone(1), "held_k_probs": self.held_k_probs.clone(1), "held_logits": self.held_logits.clone(1), + "flag_sequence": self.confidence_flag_sequence[:-1].copy(), "explored_tokens": [next_token.item()], } else: @@ -2054,28 +2072,59 @@ def rewind_checkpoint(): off_tokens = self.held_tokens.slice(len(self.checkpoint["held_tokens"]), None) off_text = self.held_text[len(self.checkpoint["held_text"]):] self.held_text = self.checkpoint["held_text"] - self.held_token = self.checkpoint["held_tokens"] + self.held_tokens = self.checkpoint["held_tokens"] self.held_probs = self.checkpoint["held_probs"] self.held_k_tokens = self.checkpoint["held_k_tokens"] self.held_k_probs = self.checkpoint["held_k_probs"] self.held_logits = self.checkpoint["held_logits"] + self.confidence_flag_sequence = self.checkpoint["flag_sequence"] self.checkpoint["offset"] = 0 return off_tokens, off_text - if self.banned_strings_utf32_offsets is not None and self.new_tokens > 0: - match = ext_c.partial_strings_match( - np.frombuffer(self.held_text.lower().encode("utf-32-le"), dtype = np.uint8), - self.banned_strings_utf32_offsets, - self.banned_strings_utf32_buffer - ) - if match >= 0: + # Handle banned strings and confidence flags using checkpointing + + if self.new_tokens > 0: + # Check for banned strings + banned_string_match = -1 + if self.banned_strings_utf32_offsets is not None: + banned_string_match = ext_c.partial_strings_match( + np.frombuffer(self.held_text.lower().encode("utf-32-le"), dtype = np.uint8), + self.banned_strings_utf32_offsets, + self.banned_strings_utf32_buffer + ) + + confidence_breaker_match = -1 + if self.confidence_breaker > 0: + # Check for confidence_flag sequence + if confidence_flag is not None: + last_n_flags = self.confidence_flag_sequence[-self.confidence_breaker:] + if not confidence_flag: + confidence_breaker_match = -1 # False flag + elif all(last_n_flags): + confidence_breaker_match = 1 # Match + else: + confidence_breaker_match = -2 # Partial match, wait and see + elif self.confidence_flag_sequence[-1]: + confidence_breaker_match = -2 # Pause current sequence without resetting partial match + else: + confidence_breaker_match = -1 # Treat None as False flag, following previous False flag + + if confidence_breaker_match >= 0: # Match confidence breaker + set_checkpoint() + if self.confidence_breaker_debug: + print(f'[Confidence breaker activated on text: "{self.held_text}"]', flush=True) + offending_tokens, offending_text = rewind_checkpoint() + return emit(results, suppressed_text = offending_text, suppressed_tokens = offending_tokens) + elif banned_string_match >= 0: set_checkpoint() offending_tokens, offending_text = rewind_checkpoint() - return emit(results, emit_held = True, suppressed_text = offending_text, suppressed_tokens = offending_tokens) - elif match == -2: + return emit(results, suppressed_text = offending_text, suppressed_tokens = offending_tokens) + elif banned_string_match == -2 or confidence_breaker_match == -2: # Partial match set_checkpoint() return emit(results) - else: + else: # Reset and permit text passthrough + if len(self.full_completion) > 0: + set_checkpoint() unset_checkpoint() # End on stop strings diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py index dcad0ee1..7cf7ba6b 100644 --- a/exllamav2/generator/sampler.py +++ b/exllamav2/generator/sampler.py @@ -73,6 +73,14 @@ class Settings: temperature_last: bool = False + logit_threshold_stats: bool = False + logit_temp_threshold: float = 0.0 + logit_min_threshold: float = 0.0 + + confidence_breaker: int = 0 + cb_mid_threshold: float = 0.0 + cb_high_threshold: float = 0.0 + mirostat: bool = False mirostat_tau: float = 1.5 mirostat_eta: float = 0.1 @@ -420,6 +428,7 @@ def prep_logit_filter(lf): # Temporarily ban individual tokens if blocked_tokens: + saved_logits = logits[:, :, blocked_tokens].clone() logits[:, :, blocked_tokens] = -1e30 # Token bias @@ -525,9 +534,18 @@ def prep_logit_filter(lf): output_ktokens = torch.empty((batch_size, 1, return_top_tokens), dtype = torch.long) output_kprobs = torch.empty((batch_size, 1, return_top_tokens), dtype = torch.float) + if settings.logit_temp_threshold > 0.0 or settings.logit_min_threshold > 0.0: + logit_filter = prep_logit_filter(logit_filter) + effective_filter = max(settings.logit_temp_threshold, settings.logit_min_threshold) + logit_filter[logits.squeeze(1) < effective_filter] = False + if not torch.any(logit_filter): + logit_filter.view(-1)[torch.argmax(logits.squeeze(1))] = True + m = ext_c.sample_basic( logits, 1.0 if settings.temperature_last else settings.temperature, + settings.logit_temp_threshold, + settings.logit_min_threshold, settings.top_k, settings.top_p, settings.top_a, @@ -555,6 +573,61 @@ def prep_logit_filter(lf): settings.skew ) + if settings.confidence_breaker > 0: + if blocked_tokens and 'saved_logits' in locals(): + # Restore the saved logits values for the blocked tokens + logits[:, :, blocked_tokens] = saved_logits + + squeezed_logits = logits.squeeze(0).squeeze(0) + probs = F.softmax(squeezed_logits, dim=-1) + token_prob = probs[output_tokens] + token_logit = squeezed_logits[output_tokens] + if settings.cb_mid_threshold <= 1.0: + confidence_flag = (token_prob >= settings.cb_mid_threshold).item() + else: + confidence_flag = (token_logit >= settings.cb_mid_threshold).item() + if settings.cb_high_threshold <= 1.0: + if (token_prob > settings.cb_high_threshold).item(): + confidence_flag = None + else: + if (token_logit > settings.cb_high_threshold).item(): + confidence_flag = None + else: + confidence_flag = None + + if settings.logit_threshold_stats: + selected_token = output_tokens + batch_logits_squeezed = logits[0, 0, :] + token_logit = batch_logits_squeezed[selected_token] + min_logit_threshold = min(torch.max(batch_logits_squeezed).item(), + settings.logit_min_threshold if settings.logit_min_threshold > 0 + else settings.logit_temp_threshold) + filtered_indices_mask = batch_logits_squeezed >= min_logit_threshold + filtered_logits = batch_logits_squeezed[filtered_indices_mask] + probs = F.softmax(batch_logits_squeezed, dim=-1) + filtered_probs = probs[filtered_indices_mask] + + # Calculate the statistics for filtered_logits + min_filtered = filtered_logits.min().item() if len(filtered_logits) > 0 else float('nan') + mean_filtered = filtered_logits.mean().item() if len(filtered_logits) > 0 else float('nan') + max_filtered = filtered_logits.max().item() if len(filtered_logits) > 0 else float('nan') + std_filtered = filtered_logits.std().item() if len(filtered_logits) > 0 else float('nan') + min_p_equivalent = filtered_probs[filtered_logits.argmin()].item() + + debug_string = ( + f"total logits: {batch_logits_squeezed.size(0):<7} " + f"filtered to: {filtered_logits.size(0):<4} " + f"min: {min_filtered:>5.2f} " + f"mean: {mean_filtered:>5.2f} " + f"max: {max_filtered:>5.2f} " + f"std: {std_filtered:>5.2f} " + f"selected logit: {token_logit.item():>5.2f} " + f"selected token: {selected_token.item():<7} " + f"min_p: {min_p_equivalent:>6.5f}" + ) + print(debug_string, flush=True) + + if settings.mirostat: settings.mirostat_mu = m # Stop condition from filters @@ -563,4 +636,4 @@ def prep_logit_filter(lf): if len(filters) > 0 and end_tokens is not None and output_tokens[0].item() in end_tokens: end_filter = True - return output_tokens, output_ktokens, output_kprobs, output_probs, end_filter + return output_tokens, output_ktokens, output_kprobs, output_probs, end_filter, confidence_flag From 50a3ef821466c568743d86f2195518d5ff5571e6 Mon Sep 17 00:00:00 2001 From: anchortense Date: Sat, 19 Oct 2024 09:41:27 +1100 Subject: [PATCH 2/4] Update sampler.py Added confidence_breaker_debug to Sampler class --- exllamav2/generator/sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py index 7cf7ba6b..a4e2684c 100644 --- a/exllamav2/generator/sampler.py +++ b/exllamav2/generator/sampler.py @@ -78,6 +78,7 @@ class Settings: logit_min_threshold: float = 0.0 confidence_breaker: int = 0 + confidence_breaker_debug: bool = False cb_mid_threshold: float = 0.0 cb_high_threshold: float = 0.0 From a6df6635503e726ca9146ccafbdf4c8b94b5cd53 Mon Sep 17 00:00:00 2001 From: anchortense Date: Wed, 23 Oct 2024 17:15:28 +1100 Subject: [PATCH 3/4] Simplified LTS implementation --- exllamav2/exllamav2_ext/cpp/sampling.cpp | 33 +++++++++-------- exllamav2/exllamav2_ext/cpp/sampling.h | 4 +- exllamav2/exllamav2_ext/ext_sampling.cpp | 26 +++++-------- exllamav2/exllamav2_ext/ext_sampling.h | 4 +- exllamav2/generator/sampler.py | 47 ++---------------------- 5 files changed, 34 insertions(+), 80 deletions(-) diff --git a/exllamav2/exllamav2_ext/cpp/sampling.cpp b/exllamav2/exllamav2_ext/cpp/sampling.cpp index 7a677244..ca308706 100644 --- a/exllamav2/exllamav2_ext/cpp/sampling.cpp +++ b/exllamav2/exllamav2_ext/cpp/sampling.cpp @@ -440,10 +440,10 @@ int sort_descending } AVX2_TARGET_OPTIONAL -int logit_threshold_restore +int logit_threshold_temperature ( - float logit_min_threshold, float logit_temp_threshold, + float logit_high_temp, const int maxlogit, const int vocab_size, const float* logits, @@ -452,51 +452,52 @@ int logit_threshold_restore int* temp_indices ) { - profile_start("logit_threshold_restore"); + profile_start("logit_threshold_temperature"); float esum = 0.0f; + float static_pmass = 0.0f; + float itemp = 1.0f / std::max(logit_high_temp, 0.01f); int n = 0; float maxl = logits[maxlogit]; - float effective_min = std::min(maxl, logit_min_threshold); for (int i = 0; i < vocab_size; i++) { float target_logit = logits[i]; - if (target_logit < effective_min) continue; + float l = target_logit - maxl; if (exponent == 2.0f) l *= -l; else if (exponent != 1.0f) l = -powf(fabs(l), exponent); - float e = expf(l); + + float e = expf(l * itemp); esum += e; - if (target_logit < logit_temp_threshold) + + if (target_logit >= logit_temp_threshold) temp_probs[i] = e; + else static_pmass += temp_probs[i]; + + n++; } float isum = 1.0f / esum; - float diffsum = 0.0f; + float temp_pmass = 0.0f; for (int i = 0; i < vocab_size; i++) { - if (logits[i] < effective_min) continue; - if (logits[i] < logit_temp_threshold) + if (logits[i] >= logit_temp_threshold) { temp_probs[i] *= isum; - diffsum += temp_probs[i]; - n++; + temp_pmass += temp_probs[i]; } } - float adjfactor = 1.0f - diffsum; + float adjfactor = (1.0f - static_pmass) / temp_pmass; for (int i = 0; i < vocab_size; i++) { if (logits[i] >= logit_temp_threshold) - { temp_probs[i] *= adjfactor; - n++; - } } sort_descending(vocab_size, temp_probs, temp_indices, n); diff --git a/exllamav2/exllamav2_ext/cpp/sampling.h b/exllamav2/exllamav2_ext/cpp/sampling.h index 316aca5c..e10675b0 100644 --- a/exllamav2/exllamav2_ext/cpp/sampling.h +++ b/exllamav2/exllamav2_ext/cpp/sampling.h @@ -55,10 +55,10 @@ int sort_descending int max_index ); -int logit_threshold_restore +int logit_threshold_temperature ( - float logit_min_threshold, float logit_temp_threshold, + float logit_high_temp, const int maxlogit, const int vocab_size, const float* logits, diff --git a/exllamav2/exllamav2_ext/ext_sampling.cpp b/exllamav2/exllamav2_ext/ext_sampling.cpp index 4a1f5c11..64cef55b 100644 --- a/exllamav2/exllamav2_ext/ext_sampling.cpp +++ b/exllamav2/exllamav2_ext/ext_sampling.cpp @@ -85,8 +85,6 @@ std::vector sample_basic ( torch::Tensor logits, // shape [bsz, 1, vocab_size] float temperature, - float logit_temp_threshold, - float logit_min_threshold, int top_k, float top_p, float top_a, @@ -111,6 +109,8 @@ std::vector sample_basic float max_temp = 0.0f, float temp_exponent = 1.0f, float smoothing_factor = 0.0f, + float logit_temp_threshold = 0.0f, + float logit_high_temp = 0.0f, float skew = 0.0f ) { @@ -140,8 +140,11 @@ std::vector sample_basic if (temperature < 0.01) { - temperature = 1.0f; - top_k = 1; + if (logit_temp_threshold == 0.0f) + { + temperature = 1.0f; + top_k = 1; + } } for (int i = 0; i < bsz; i++) @@ -166,19 +169,10 @@ std::vector sample_basic for (int j = 0; j < vocab_size; j++) temp_indices[j] = j; int num_candidates = vocab_size; - if ((logit_temp_threshold > logit_min_threshold) && logit_min_threshold > 0.0f) + if (logit_temp_threshold > 0.0f) { - num_candidates = logit_threshold_restore - ( - logit_min_threshold, - logit_temp_threshold, - maxlogit, - vocab_size, - logits_ptr + i * vocab_size, - exponent, - temp_probs, - temp_indices - ); + num_candidates = logit_threshold_temperature(logit_temp_threshold, logit_high_temp, maxlogit, vocab_size, logits_ptr + i * vocab_size, exponent, temp_probs, temp_indices); + normalize_cpu(num_candidates, temp_probs); } if (num_candidates > top_k && top_k > 0 && top_k < vocab_size) diff --git a/exllamav2/exllamav2_ext/ext_sampling.h b/exllamav2/exllamav2_ext/ext_sampling.h index 51c4ed4c..e56ff73e 100644 --- a/exllamav2/exllamav2_ext/ext_sampling.h +++ b/exllamav2/exllamav2_ext/ext_sampling.h @@ -16,8 +16,6 @@ std::vector sample_basic ( torch::Tensor logits, // shape [bsz, vocab_size] float temperature, - float logit_temp_threshold, - float logit_min_threshold, int top_k, float top_p, float top_a, @@ -42,6 +40,8 @@ std::vector sample_basic float max_temp, float temp_exponent, float smoothing_factor, + float logit_temp_threshold, + float logit_high_temp, float skew ); diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py index a4e2684c..de852568 100644 --- a/exllamav2/generator/sampler.py +++ b/exllamav2/generator/sampler.py @@ -73,9 +73,8 @@ class Settings: temperature_last: bool = False - logit_threshold_stats: bool = False logit_temp_threshold: float = 0.0 - logit_min_threshold: float = 0.0 + logit_high_temp: float = 0.0 confidence_breaker: int = 0 confidence_breaker_debug: bool = False @@ -535,18 +534,9 @@ def prep_logit_filter(lf): output_ktokens = torch.empty((batch_size, 1, return_top_tokens), dtype = torch.long) output_kprobs = torch.empty((batch_size, 1, return_top_tokens), dtype = torch.float) - if settings.logit_temp_threshold > 0.0 or settings.logit_min_threshold > 0.0: - logit_filter = prep_logit_filter(logit_filter) - effective_filter = max(settings.logit_temp_threshold, settings.logit_min_threshold) - logit_filter[logits.squeeze(1) < effective_filter] = False - if not torch.any(logit_filter): - logit_filter.view(-1)[torch.argmax(logits.squeeze(1))] = True - m = ext_c.sample_basic( logits, 1.0 if settings.temperature_last else settings.temperature, - settings.logit_temp_threshold, - settings.logit_min_threshold, settings.top_k, settings.top_p, settings.top_a, @@ -571,6 +561,8 @@ def prep_logit_filter(lf): settings.max_temp, settings.temp_exponent, settings.smoothing_factor, + settings.logit_temp_threshold, + settings.logit_high_temp, settings.skew ) @@ -596,39 +588,6 @@ def prep_logit_filter(lf): else: confidence_flag = None - if settings.logit_threshold_stats: - selected_token = output_tokens - batch_logits_squeezed = logits[0, 0, :] - token_logit = batch_logits_squeezed[selected_token] - min_logit_threshold = min(torch.max(batch_logits_squeezed).item(), - settings.logit_min_threshold if settings.logit_min_threshold > 0 - else settings.logit_temp_threshold) - filtered_indices_mask = batch_logits_squeezed >= min_logit_threshold - filtered_logits = batch_logits_squeezed[filtered_indices_mask] - probs = F.softmax(batch_logits_squeezed, dim=-1) - filtered_probs = probs[filtered_indices_mask] - - # Calculate the statistics for filtered_logits - min_filtered = filtered_logits.min().item() if len(filtered_logits) > 0 else float('nan') - mean_filtered = filtered_logits.mean().item() if len(filtered_logits) > 0 else float('nan') - max_filtered = filtered_logits.max().item() if len(filtered_logits) > 0 else float('nan') - std_filtered = filtered_logits.std().item() if len(filtered_logits) > 0 else float('nan') - min_p_equivalent = filtered_probs[filtered_logits.argmin()].item() - - debug_string = ( - f"total logits: {batch_logits_squeezed.size(0):<7} " - f"filtered to: {filtered_logits.size(0):<4} " - f"min: {min_filtered:>5.2f} " - f"mean: {mean_filtered:>5.2f} " - f"max: {max_filtered:>5.2f} " - f"std: {std_filtered:>5.2f} " - f"selected logit: {token_logit.item():>5.2f} " - f"selected token: {selected_token.item():<7} " - f"min_p: {min_p_equivalent:>6.5f}" - ) - print(debug_string, flush=True) - - if settings.mirostat: settings.mirostat_mu = m # Stop condition from filters From 618132c0af5a02dc35c141c50d4bea9fcc88e984 Mon Sep 17 00:00:00 2001 From: anchortense Date: Thu, 21 Nov 2024 11:10:01 +1100 Subject: [PATCH 4/4] Efficiency improvement --- exllamav2/exllamav2_ext/cpp/sampling.cpp | 38 ++++++++++-------------- exllamav2/exllamav2_ext/cpp/sampling.h | 7 ++--- exllamav2/exllamav2_ext/ext_sampling.cpp | 2 +- exllamav2/generator/dynamic.py | 2 +- 4 files changed, 20 insertions(+), 29 deletions(-) diff --git a/exllamav2/exllamav2_ext/cpp/sampling.cpp b/exllamav2/exllamav2_ext/cpp/sampling.cpp index ca308706..1bf83fc0 100644 --- a/exllamav2/exllamav2_ext/cpp/sampling.cpp +++ b/exllamav2/exllamav2_ext/cpp/sampling.cpp @@ -440,16 +440,15 @@ int sort_descending } AVX2_TARGET_OPTIONAL -int logit_threshold_temperature +void logit_threshold_temperature ( + const int num_candidates, float logit_temp_threshold, float logit_high_temp, const int maxlogit, - const int vocab_size, const float* logits, const float exponent, - float* temp_probs, - int* temp_indices + float* temp_probs ) { profile_start("logit_threshold_temperature"); @@ -457,10 +456,10 @@ int logit_threshold_temperature float esum = 0.0f; float static_pmass = 0.0f; float itemp = 1.0f / std::max(logit_high_temp, 0.01f); - int n = 0; float maxl = logits[maxlogit]; + std::vector above_threshold_indices; - for (int i = 0; i < vocab_size; i++) + for (int i = 0; i < num_candidates; i++) { float target_logit = logits[i]; @@ -474,37 +473,30 @@ int logit_threshold_temperature esum += e; if (target_logit >= logit_temp_threshold) + { temp_probs[i] = e; + above_threshold_indices.push_back(i); + } else static_pmass += temp_probs[i]; - - n++; } - float isum = 1.0f / esum; + float isum = (esum >= 0.0f) ? (1.0f / esum) : 1024.0f; float temp_pmass = 0.0f; - for (int i = 0; i < vocab_size; i++) + for (int i : above_threshold_indices) { - if (logits[i] >= logit_temp_threshold) - { - temp_probs[i] *= isum; - temp_pmass += temp_probs[i]; - } + temp_probs[i] *= isum; + temp_pmass += temp_probs[i]; } - float adjfactor = (1.0f - static_pmass) / temp_pmass; + float adjfactor = (temp_pmass >= 0.0f) ? ((1.0f - static_pmass) / temp_pmass) : 1024.0f; - for (int i = 0; i < vocab_size; i++) + for (int i : above_threshold_indices) { - if (logits[i] >= logit_temp_threshold) - temp_probs[i] *= adjfactor; + temp_probs[i] *= adjfactor; } - sort_descending(vocab_size, temp_probs, temp_indices, n); - profile_stop(); - return n; - } AVX2_TARGET_OPTIONAL diff --git a/exllamav2/exllamav2_ext/cpp/sampling.h b/exllamav2/exllamav2_ext/cpp/sampling.h index e10675b0..33669159 100644 --- a/exllamav2/exllamav2_ext/cpp/sampling.h +++ b/exllamav2/exllamav2_ext/cpp/sampling.h @@ -55,16 +55,15 @@ int sort_descending int max_index ); -int logit_threshold_temperature +void logit_threshold_temperature ( + const int num_candidates, float logit_temp_threshold, float logit_high_temp, const int maxlogit, - const int vocab_size, const float* logits, const float exponent, - float* temp_probs, - int* temp_indices + float* temp_probs ); int top_k_cpu diff --git a/exllamav2/exllamav2_ext/ext_sampling.cpp b/exllamav2/exllamav2_ext/ext_sampling.cpp index 64cef55b..c4704285 100644 --- a/exllamav2/exllamav2_ext/ext_sampling.cpp +++ b/exllamav2/exllamav2_ext/ext_sampling.cpp @@ -171,7 +171,7 @@ std::vector sample_basic if (logit_temp_threshold > 0.0f) { - num_candidates = logit_threshold_temperature(logit_temp_threshold, logit_high_temp, maxlogit, vocab_size, logits_ptr + i * vocab_size, exponent, temp_probs, temp_indices); + logit_threshold_temperature(num_candidates, logit_temp_threshold, logit_high_temp, maxlogit, logits_ptr + i * vocab_size, exponent, temp_probs); normalize_cpu(num_candidates, temp_probs); } diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py index 8cd62898..e8e082f2 100644 --- a/exllamav2/generator/dynamic.py +++ b/exllamav2/generator/dynamic.py @@ -2002,7 +2002,7 @@ def emit( # End on stop tokens if next_token.item() in self.stop_tokens: - return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item()) + return emit(results, emit_eos = True, emit_held = True, eos_reason = "stop_token", stop_token = next_token.item()) # Stop if we reach max_new_tokens # TODO: Auto-extend option