@@ -481,7 +481,7 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
481
481
}
482
482
483
483
int SampleLogits (const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
484
- int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar)
484
+ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers> & sampler_order, llama_grammar * grammar, bool dynatemp, float min_temp, float max_temp )
485
485
{
486
486
int id = 0 ;
487
487
std::vector<llama_token_data> candidates;
@@ -540,7 +540,14 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
540
540
llama_sample_typical (nullptr , &candidates_p, typical_p,1 );
541
541
break ;
542
542
case KCPP_SAMPLER_TEMP:
543
- sample_temperature (&candidates_p, temp);
543
+ if (dynatemp)
544
+ {
545
+ llama_sample_entropy (nullptr , &candidates_p, temp, min_temp, max_temp);
546
+ }
547
+ else
548
+ {
549
+ sample_temperature (&candidates_p, temp);
550
+ }
544
551
break ;
545
552
case KCPP_SAMPLER_REP_PEN:
546
553
sample_rep_pen (n_ctx, rep_pen_range, rep_pen, presence_penalty, &candidates_p);
@@ -1479,6 +1486,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
1479
1486
}
1480
1487
1481
1488
std::string addedmemory = inputs.memory ;
1489
+
1482
1490
kcpp_params->prompt = inputs.prompt ;
1483
1491
kcpp_params->seed = inputs.seed ;
1484
1492
kcpp_params->n_predict = inputs.max_length ;
@@ -1494,10 +1502,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
1494
1502
kcpp_params->mirostat = inputs.mirostat ;
1495
1503
kcpp_params->mirostat_eta = inputs.mirostat_eta ;
1496
1504
kcpp_params->mirostat_tau = inputs.mirostat_tau ;
1505
+ kcpp_params->dynatemp = inputs.dynatemp ;
1506
+ kcpp_params->min_temp = inputs.min_temp ;
1507
+ kcpp_params->max_temp = inputs.max_temp ;
1497
1508
kcpp_params->n_ctx = inputs.max_context_length ;
1498
1509
kcpp_params->n_batch = n_batch;
1499
1510
kcpp_params->n_threads = n_threads;
1500
1511
kcpp_params->n_threads_batch = n_blasthreads;
1512
+
1501
1513
bool stream_sse = inputs.stream_sse ;
1502
1514
1503
1515
bool allow_regular_prints = (debugmode!=-1 && !inputs.quiet ) || debugmode >= 1 ;
@@ -1888,6 +1900,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
1888
1900
const float presence_penalty = kcpp_params->presence_penalty ;
1889
1901
const float typical_p = kcpp_params->typical_p ;
1890
1902
const float tfs_z = kcpp_params->tfs_z ;
1903
+ const bool dynatemp = kcpp_params->dynatemp ; // on/off switch: forwarded to SampleLogits' `bool dynatemp` parameter, so keep it boolean rather than float
1904
+ const float min_temp = kcpp_params->min_temp ;
1905
+ const float max_temp = kcpp_params->max_temp ;
1891
1906
1892
1907
if (!startedsampling)
1893
1908
{
@@ -1943,7 +1958,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
1943
1958
1944
1959
id = SampleLogits (logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, presence_penalty,
1945
1960
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
1946
- kcpp_params->mirostat , kcpp_params->mirostat_tau , kcpp_params->mirostat_eta , sampler_order, grammar);
1961
+ kcpp_params->mirostat , kcpp_params->mirostat_tau , kcpp_params->mirostat_eta , sampler_order, grammar, dynatemp, min_temp, max_temp );
1947
1962
1948
1963
if (grammar != nullptr ) {
1949
1964
grammar_accept_token (file_format, n_vocab, grammar, id);
0 commit comments