diff --git a/exllamav2/exllamav2_ext/cpp/sampling.cpp b/exllamav2/exllamav2_ext/cpp/sampling.cpp
index 8f9ae7e6..1bf83fc0 100644
--- a/exllamav2/exllamav2_ext/cpp/sampling.cpp
+++ b/exllamav2/exllamav2_ext/cpp/sampling.cpp
@@ -439,6 +439,71 @@ int sort_descending
     return pre;
 }
 
+AVX2_TARGET_OPTIONAL
+void logit_threshold_temperature
+(
+    const int num_candidates,
+    float logit_temp_threshold,
+    float logit_high_temp,
+    const int maxlogit,
+    const float* logits,
+    const float exponent,
+    float* temp_probs
+)
+{
+    profile_start("logit_threshold_temperature");
+
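+    // Tokens with logits at or above logit_temp_threshold are re-softmaxed at
+    // logit_high_temp; tokens below it keep the probabilities already present in
+    // temp_probs, and their total mass ("static_pmass") is preserved.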
+    float esum = 0.0f;
+    float static_pmass = 0.0f;
+    float itemp = 1.0f / std::max(logit_high_temp, 0.01f);
+    float maxl = logits[maxlogit];
+    std::vector<int> above_threshold_indices;
+
+    for (int i = 0; i < num_candidates; i++)
+    {
+        float target_logit = logits[i];
+
+        float l = target_logit - maxl;
+        if (exponent == 2.0f)
+            l *= -l;
+        else if (exponent != 1.0f)
+            l = -powf(fabs(l), exponent);
+
+        float e = expf(l * itemp);
+        esum += e;
+
+        if (target_logit >= logit_temp_threshold)
+        {
+            temp_probs[i] = e;
+            above_threshold_indices.push_back(i);
+        }
+        else static_pmass += temp_probs[i];
+    }
+
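+    // Normalize the re-tempered tokens, then rescale them to fill exactly the
+    // probability mass not held by the below-threshold tokens.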
+    float isum = (esum > 0.0f) ? (1.0f / esum) : 1024.0f;
+    float temp_pmass = 0.0f;
+
+    for (int i : above_threshold_indices)
+    {
+        temp_probs[i] *= isum;
+        temp_pmass += temp_probs[i];
+    }
+
+    float adjfactor = (temp_pmass > 0.0f) ? ((1.0f - static_pmass) / temp_pmass) : 1024.0f;
+
+    for (int i : above_threshold_indices)
+    {
+        temp_probs[i] *= adjfactor;
+    }
+
+    profile_stop();
+}
+
 AVX2_TARGET_OPTIONAL
 int top_k_cpu
 (
diff --git a/exllamav2/exllamav2_ext/cpp/sampling.h b/exllamav2/exllamav2_ext/cpp/sampling.h
index 814f90b2..33669159 100644
--- a/exllamav2/exllamav2_ext/cpp/sampling.h
+++ b/exllamav2/exllamav2_ext/cpp/sampling.h
@@ -55,6 +55,17 @@ int sort_descending
     int max_index
 );
 
+void logit_threshold_temperature
+(
+    const int num_candidates,
+    float logit_temp_threshold,
+    float logit_high_temp,
+    const int maxlogit,
+    const float* logits,
+    const float exponent,
+    float* temp_probs
+);
+
 int top_k_cpu
 (
     const int num_candidates,
diff --git a/exllamav2/exllamav2_ext/ext_sampling.cpp b/exllamav2/exllamav2_ext/ext_sampling.cpp
index 2bf312b1..c4704285 100644
--- a/exllamav2/exllamav2_ext/ext_sampling.cpp
+++ b/exllamav2/exllamav2_ext/ext_sampling.cpp
@@ -109,6 +109,8 @@ std::vector<float> sample_basic
     float max_temp = 0.0f,
     float temp_exponent = 1.0f,
     float smoothing_factor = 0.0f,
+    float logit_temp_threshold = 0.0f,
+    float logit_high_temp = 0.0f,
     float skew = 0.0f
 )
 {
@@ -138,8 +140,13 @@ std::vector<float> sample_basic
 
     if (temperature < 0.01)
     {
-        temperature = 1.0f;
-        top_k = 1;
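+        // Fall back to greedy decoding only when threshold temperature is disabled;
+        // otherwise keep sampling so the high-temperature pass below still applies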
+        if (logit_temp_threshold == 0.0f)
+        {
+            temperature = 1.0f;
+            top_k = 1;
+        }
     }
 
     for (int i = 0; i < bsz; i++)
@@ -164,7 +171,14 @@ std::vector<float> sample_basic
         for (int j = 0; j < vocab_size; j++) temp_indices[j] = j;
         int num_candidates = vocab_size;
 
-        if (top_k > 0 && top_k < vocab_size)
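+        // Re-temper the logits above the threshold at logit_high_temp, then renormalize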
+        if (logit_temp_threshold > 0.0f)
+        {
+            logit_threshold_temperature(num_candidates, logit_temp_threshold, logit_high_temp, maxlogit, logits_ptr + i * vocab_size, exponent, temp_probs);
+            normalize_cpu(num_candidates, temp_probs);
+        }
+
+        if (num_candidates > top_k && top_k > 0 && top_k < vocab_size)
         {
             num_candidates = top_k_cpu(num_candidates, temp_probs, temp_indices, top_k, maxlogit);
             normalize_cpu(num_candidates, temp_probs);
diff --git a/exllamav2/exllamav2_ext/ext_sampling.h b/exllamav2/exllamav2_ext/ext_sampling.h
index 45d01a52..e56ff73e 100644
--- a/exllamav2/exllamav2_ext/ext_sampling.h
+++ b/exllamav2/exllamav2_ext/ext_sampling.h
@@ -40,6 +40,8 @@ std::vector<float> sample_basic
     float max_temp,
     float temp_exponent,
     float smoothing_factor,
+    float logit_temp_threshold,
+    float logit_high_temp,
     float skew
 );
 
diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py
index bce25ae8..e8e082f2 100644
--- a/exllamav2/generator/dynamic.py
+++ b/exllamav2/generator/dynamic.py
@@ -525,7 +525,7 @@ def set_loras(self, loras: list[ExLlamaV2Lora] | None):
             self.current_loras = loras
         else:
             self.current_loras = [loras]
-        
+
 
     def generate(
         self,
@@ -1170,10 +1170,10 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None):
             for i in range(batch_logits.shape[1]):
                 job_logits = batch_logits[a:b, i:i+1, :]
                 if i == 0 and mt_sample:
-                    next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \
+                    next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag = \
                     futures.popleft().result()
                 else:
-                    next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \
+                    next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag = \
                     job.receive_logits(job_logits)
 
                 eos, sampled_token = job.receive_sample(
@@ -1183,6 +1183,7 @@ def iterate_gen(self, results: list, draft_tokens: torch.Tensor | None = None):
                     next_k_probs,
                     next_prob,
                     filter_eos,
+                    confidence_flag,
                     results
                 )
 
@@ -1661,6 +1662,12 @@ def __init__(
 
         self.checkpoint = None
 
+        # Confidence breaker
+
+        self.confidence_breaker = gen_settings.confidence_breaker
+        self.confidence_breaker_debug = gen_settings.confidence_breaker_debug
+        self.confidence_flag_sequence = [False] * self.confidence_breaker
+
         # Measurement
 
         self.time_enqueue = None
@@ -1762,7 +1769,7 @@ def receive_logits(
             else:
                 blocked_tokens = self.stop_tokens_list
 
-        next_token, next_k_tokens, next_k_probs, next_prob, filter_eos = \
+        next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag = \
         ExLlamaV2Sampler.sample(
             logits,
             self.gen_settings,
@@ -1777,7 +1784,7 @@ def receive_logits(
             # sync = True
         )
 
-        return next_token, next_k_tokens, next_k_probs, next_prob, filter_eos
+        return next_token, next_k_tokens, next_k_probs, next_prob, filter_eos, confidence_flag
 
 
     def receive_sample(
@@ -1788,6 +1795,7 @@ def receive_sample(
             next_k_probs: torch.Tensor | None,
             next_prob: torch.Tensor | None,
             filter_eos: bool | None,
+            confidence_flag: bool | None,
             results: list
     ):
         page_size = self.generator.page_size
@@ -1801,6 +1809,14 @@ def receive_sample(
                 if f.use_background_worker():
                     self.generator.filter_queue.append(f)
 
+        # Update confidence_flag_sequence
+
+        if confidence_flag is not None:
+            self.confidence_flag_sequence.append(confidence_flag)
+            # Limit the size of the sequence to prevent it from growing indefinitely
+            if len(self.confidence_flag_sequence) > self.confidence_breaker + 1:
+                self.confidence_flag_sequence.pop(0)
+
         # Accept token
 
         self.new_tokens += 1
@@ -1986,7 +2002,7 @@ def emit(
         # End on stop tokens
 
         if next_token.item() in self.stop_tokens:
-            return emit(results, emit_eos = True, eos_reason = "stop_token", stop_token = next_token.item())
+            return emit(results, emit_eos = True, emit_held = True, eos_reason = "stop_token", stop_token = next_token.item())
 
         # Stop if we reach max_new_tokens
         # TODO: Auto-extend option
@@ -2011,7 +2027,8 @@ def emit(
             else:
                 return emit(results)
 
-        # Hold text as long as it contains part of a banned string
+        # Hold text as long as it contains part of a banned string,
+        # or until we know a confidence breaker will not be triggered
 
         def unset_checkpoint():
             self.checkpoint = None
@@ -2026,6 +2043,7 @@ def set_checkpoint():
                     "held_k_tokens": self.held_k_tokens.clone(1),
                     "held_k_probs": self.held_k_probs.clone(1),
                     "held_logits": self.held_logits.clone(1),
+                    "flag_sequence": self.confidence_flag_sequence[:-1],
                     "explored_tokens": [next_token.item()],
                 }
             else:
@@ -2054,28 +2072,61 @@ def rewind_checkpoint():
             off_tokens = self.held_tokens.slice(len(self.checkpoint["held_tokens"]), None)
             off_text = self.held_text[len(self.checkpoint["held_text"]):]
             self.held_text = self.checkpoint["held_text"]
-            self.held_token = self.checkpoint["held_tokens"]
+            self.held_tokens = self.checkpoint["held_tokens"]
             self.held_probs = self.checkpoint["held_probs"]
             self.held_k_tokens = self.checkpoint["held_k_tokens"]
             self.held_k_probs = self.checkpoint["held_k_probs"]
             self.held_logits = self.checkpoint["held_logits"]
+            self.confidence_flag_sequence = self.checkpoint["flag_sequence"]
             self.checkpoint["offset"] = 0
             return off_tokens, off_text
 
-        if self.banned_strings_utf32_offsets is not None and self.new_tokens > 0:
-            match = ext_c.partial_strings_match(
-                np.frombuffer(self.held_text.lower().encode("utf-32-le"), dtype = np.uint8),
-                self.banned_strings_utf32_offsets,
-                self.banned_strings_utf32_buffer
-            )
-            if match >= 0:
+        # Handle banned strings and confidence flags using checkpointing
+
+        if self.new_tokens > 0:
+            # Check for banned strings
+            banned_string_match = -1
+            if self.banned_strings_utf32_offsets is not None:
+                banned_string_match = ext_c.partial_strings_match(
+                    np.frombuffer(self.held_text.lower().encode("utf-32-le"), dtype = np.uint8),
+                    self.banned_strings_utf32_offsets,
+                    self.banned_strings_utf32_buffer
+                )
+
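+            # confidence_breaker_match: 1 = full run of flagged tokens (rewind and
+            # suppress), -2 = partial run in progress (hold output), -1 = no match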
+            confidence_breaker_match = -1
+            if self.confidence_breaker > 0:
+                # Check for confidence_flag sequence
+                if confidence_flag is not None:
+                    last_n_flags = self.confidence_flag_sequence[-self.confidence_breaker:]
+                    if not confidence_flag:
+                        confidence_breaker_match = -1  # False flag
+                    elif all(last_n_flags):
+                        confidence_breaker_match = 1  # Match
+                    else:
+                        confidence_breaker_match = -2  # Partial match, wait and see
+                elif self.confidence_flag_sequence[-1]:
+                    confidence_breaker_match = -2  # Pause current sequence without resetting partial match
+                else:
+                    confidence_breaker_match = -1  # No run in progress; treat None like a False flag
+
+            if confidence_breaker_match >= 0:  # Match confidence breaker
+                set_checkpoint()
+                if self.confidence_breaker_debug:
+                    print(f'[Confidence breaker activated on text: "{self.held_text}"]', flush=True)
+                offending_tokens, offending_text = rewind_checkpoint()
+                return emit(results, suppressed_text = offending_text, suppressed_tokens = offending_tokens)
+            elif banned_string_match >= 0:
                 set_checkpoint()
                 offending_tokens, offending_text = rewind_checkpoint()
-                return emit(results, emit_held = True, suppressed_text = offending_text, suppressed_tokens = offending_tokens)
-            elif match == -2:
+                return emit(results, suppressed_text = offending_text, suppressed_tokens = offending_tokens)
+            elif banned_string_match == -2 or confidence_breaker_match == -2:  # Partial match
                 set_checkpoint()
                 return emit(results)
-            else:
+            else:  # Reset and permit text passthrough
+                if len(self.full_completion) > 0:
+                    set_checkpoint()
                 unset_checkpoint()
 
         # End on stop strings
diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py
index dcad0ee1..de852568 100644
--- a/exllamav2/generator/sampler.py
+++ b/exllamav2/generator/sampler.py
@@ -73,6 +73,14 @@ class Settings:
 
         temperature_last: bool = False
 
+        logit_temp_threshold: float = 0.0
+        logit_high_temp: float = 0.0
+
+        confidence_breaker: int = 0
+        confidence_breaker_debug: bool = False
+        cb_mid_threshold: float = 0.0
+        cb_high_threshold: float = 0.0
+
         mirostat: bool = False
         mirostat_tau: float = 1.5
         mirostat_eta: float = 0.1
@@ -420,6 +428,7 @@ def prep_logit_filter(lf):
         # Temporarily ban individual tokens
 
         if blocked_tokens:
+            saved_logits = logits[:, :, blocked_tokens].clone()
             logits[:, :, blocked_tokens] = -1e30
 
         # Token bias
@@ -552,9 +561,36 @@ def prep_logit_filter(lf):
             settings.max_temp,
             settings.temp_exponent,
             settings.smoothing_factor,
+            settings.logit_temp_threshold,
+            settings.logit_high_temp,
             settings.skew
         )
 
+        if settings.confidence_breaker > 0:
+            if blocked_tokens:
+                # Restore the saved logits values for the blocked tokens
+                logits[:, :, blocked_tokens] = saved_logits
+
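+            # Confidence flag for the sampled token: thresholds <= 1.0 compare against
+            # its probability, larger values against its raw logit; tokens above
+            # cb_high_threshold are exempt from the breaker (None)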
+            squeezed_logits = logits.squeeze(0).squeeze(0)
+            probs = torch.softmax(squeezed_logits, dim = -1)
+            token_prob = probs[output_tokens]
+            token_logit = squeezed_logits[output_tokens]
+            if settings.cb_mid_threshold <= 1.0:
+                confidence_flag = (token_prob >= settings.cb_mid_threshold).item()
+            else:
+                confidence_flag = (token_logit >= settings.cb_mid_threshold).item()
+            if settings.cb_high_threshold <= 1.0:
+                if (token_prob > settings.cb_high_threshold).item():
+                    confidence_flag = None
+            else:
+                if (token_logit > settings.cb_high_threshold).item():
+                    confidence_flag = None
+        else:
+            confidence_flag = None
+
         if settings.mirostat: settings.mirostat_mu = m
 
         # Stop condition from filters
@@ -563,4 +599,4 @@ def prep_logit_filter(lf):
         if len(filters) > 0 and end_tokens is not None and output_tokens[0].item() in end_tokens:
             end_filter = True
 
-        return output_tokens, output_ktokens, output_kprobs, output_probs, end_filter
+        return output_tokens, output_ktokens, output_kprobs, output_probs, end_filter, confidence_flag