Skip to content

Commit 79aeadf

Browse files
authored
Merge Dynamic Temp UI + cleanups into dynatemp-pr-upstream
2 parents 9e0dee7 + f61a441 commit 79aeadf

9 files changed

+182
-7
lines changed

common/common.h

+6
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ struct gpt_params {
8080
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
8181
float mirostat_tau = 5.00f; // target entropy
8282
float mirostat_eta = 0.10f; // learning rate
83+
84+
// DynaTemp!
85+
bool dynatemp = false; // enable DynaTemp
86+
float min_temp = 0.00f; // minimum temperature
87+
float max_temp = 2.00f; // maximum temperature
88+
8389
// // sampling parameters
8490
struct llama_sampling_params sparams;
8591

common/sampling.h

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ typedef struct llama_sampling_params {
2525
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
2626
float mirostat_tau = 5.00f; // target entropy
2727
float mirostat_eta = 0.10f; // learning rate
28+
bool dynatemp = false; // dynamic temperature
29+
float min_temp = 0.00f; // minimum temperature
30+
float max_temp = 2.00f; // maximum temperature
2831
bool penalize_nl = true; // consider newlines as a repeatable token
2932
std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
3033

expose.h

+4
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@ struct generation_inputs
8181
const char * grammar;
8282
const bool grammar_retain_state;
8383
const bool quiet = false;
84+
const bool dynatemp = false;
85+
const float min_temp;
86+
const float max_temp;
8487
const logit_bias logit_biases[logit_bias_max];
88+
8589
};
8690
struct generation_outputs
8791
{

gpttype_adapter.cpp

+18-3
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
481481
}
482482

483483
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
484-
int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar)
484+
int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar, bool dynatemp, float min_temp, float max_temp)
485485
{
486486
int id = 0;
487487
std::vector<llama_token_data> candidates;
@@ -540,7 +540,14 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
540540
llama_sample_typical(nullptr, &candidates_p, typical_p,1);
541541
break;
542542
case KCPP_SAMPLER_TEMP:
543-
sample_temperature(&candidates_p, temp);
543+
if (dynatemp)
544+
{
545+
llama_sample_entropy(nullptr, &candidates_p, temp, min_temp, max_temp);
546+
}
547+
else
548+
{
549+
sample_temperature(&candidates_p, temp);
550+
}
544551
break;
545552
case KCPP_SAMPLER_REP_PEN:
546553
sample_rep_pen(n_ctx, rep_pen_range, rep_pen, presence_penalty, &candidates_p);
@@ -1479,6 +1486,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
14791486
}
14801487

14811488
std::string addedmemory = inputs.memory;
1489+
14821490
kcpp_params->prompt = inputs.prompt;
14831491
kcpp_params->seed = inputs.seed;
14841492
kcpp_params->n_predict = inputs.max_length;
@@ -1494,10 +1502,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
14941502
kcpp_params->mirostat = inputs.mirostat;
14951503
kcpp_params->mirostat_eta = inputs.mirostat_eta;
14961504
kcpp_params->mirostat_tau = inputs.mirostat_tau;
1505+
kcpp_params->dynatemp = inputs.dynatemp;
1506+
kcpp_params->min_temp = inputs.min_temp;
1507+
kcpp_params->max_temp = inputs.max_temp;
14971508
kcpp_params->n_ctx = inputs.max_context_length;
14981509
kcpp_params->n_batch = n_batch;
14991510
kcpp_params->n_threads = n_threads;
15001511
kcpp_params->n_threads_batch = n_blasthreads;
1512+
15011513
bool stream_sse = inputs.stream_sse;
15021514

15031515
bool allow_regular_prints = (debugmode!=-1 && !inputs.quiet) || debugmode >= 1;
@@ -1888,6 +1900,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
18881900
const float presence_penalty = kcpp_params->presence_penalty;
18891901
const float typical_p = kcpp_params->typical_p;
18901902
const float tfs_z = kcpp_params->tfs_z;
1903+
const float dynatemp = kcpp_params->dynatemp;
1904+
const float min_temp = kcpp_params->min_temp;
1905+
const float max_temp = kcpp_params->max_temp;
18911906

18921907
if (!startedsampling)
18931908
{
@@ -1943,7 +1958,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
19431958

19441959
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, presence_penalty,
19451960
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
1946-
kcpp_params->mirostat, kcpp_params->mirostat_tau, kcpp_params->mirostat_eta, sampler_order, grammar);
1961+
kcpp_params->mirostat, kcpp_params->mirostat_tau, kcpp_params->mirostat_eta, sampler_order, grammar, dynatemp, min_temp, max_temp);
19471962

19481963
if (grammar != nullptr) {
19491964
grammar_accept_token(file_format, n_vocab, grammar, id);

kcpp_docs.embd

+16-1
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,21 @@
139139
"description": "If true, prevents the EOS token from being generated (Ban EOS). For unbantokens, set this to false.",
140140
"type": "boolean"
141141
},
142+
"dynatemp": {
143+
"default": false,
144+
"description": "If true, uses dynamic temperature. If false, uses static temperature.",
145+
"type": "boolean"
146+
},
147+
"min_temp": {
148+
"description": "Dynatemp Minimum temperature value.",
149+
"exclusiveMinimum": 0,
150+
"type": "number"
151+
},
152+
"max_temp": {
153+
"description": "Maximum temperature value.",
154+
"exclusiveMinimum": 0,
155+
"type": "number"
156+
},
142157
"mirostat": {
143158
"description": "KoboldCpp ONLY. Sets the mirostat mode, 0=disabled, 1=mirostat_v1, 2=mirostat_v2",
144159
"minimum": 0,
@@ -876,4 +891,4 @@
876891

877892
</body>
878893

879-
</html>
894+
</html>

0 commit comments

Comments
 (0)