Skip to content

Commit ec323ee

Browse files
committed
smooth sampling v2
1 parent 124b7f3 commit ec323ee

File tree

2 files changed

+9
-31
lines changed

2 files changed

+9
-31
lines changed

ExtStuff.txt

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
smoothing_factor=5.0
1+
smoothing_factor=0.5

llama.cpp

+8 -30
Original file line number | Diff line number | Diff line change
@@ -8412,19 +8412,14 @@ void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * cand
84128412
printf("Token[%zu] = %f\n", i, candidates->data[i].p);
84138413
}
84148414

8415-
// Find min and max logits for normalization
8416-
float min_logit = candidates->data[0].logit;
8417-
float max_logit = candidates->data[0].logit;
8418-
for (size_t i = 1; i < candidates->size; ++i) {
8419-
if (candidates->data[i].logit < min_logit) min_logit = candidates->data[i].logit;
8420-
if (candidates->data[i].logit > max_logit) max_logit = candidates->data[i].logit;
8421-
}
8415+
float h = candidates->data[0].logit; // Find the maximum logit for h
8416+
float k = h; // Maximum logit value to be added after the transformation
84228417

84238418
// Read smoothing_factor from "ExtStuff.txt"
84248419
float smoothing_factor = 0;
84258420
FILE* file = fopen("ExtStuff.txt", "r");
84268421
if (file) {
8427-
if (fscanf(file, "smoothing_factor=%f", &smoothing_factor) != 1) { // Corrected variable name here
8422+
if (fscanf(file, "smoothing_factor=%f", &smoothing_factor) != 1) {
84288423
smoothing_factor = 0;
84298424
}
84308425
fclose(file);
@@ -8441,28 +8436,11 @@ void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * cand
84418436

84428437
// Only apply smoothing if smoothing_factor is not 0
84438438
if (smoothing_factor != 0) {
8444-
// Apply the remapping and sigmoid function
8445-
float new_min_logit = FLT_MAX;
8446-
float new_max_logit = FLT_MIN;
8447-
for (size_t i = 0; i < candidates->size; ++i) {
8448-
// Normalize the logits to the [0,1] range
8449-
float normalized_logit = (candidates->data[i].logit - min_logit) / (max_logit - min_logit);
8450-
8451-
// Apply the sigmoid function to the normalized logits
8452-
float sigmoid_logit = 1.0f / (1.0f + expf(-smoothing_factor * (normalized_logit - 0.5f)));
8453-
8454-
// Update the logits with the smoothed values
8455-
candidates->data[i].logit = sigmoid_logit * (max_logit - min_logit) + min_logit;
8456-
8457-
// Find new min and max logits after smoothing
8458-
if (candidates->data[i].logit < new_min_logit) new_min_logit = candidates->data[i].logit;
8459-
if (candidates->data[i].logit > new_max_logit) new_max_logit = candidates->data[i].logit;
8460-
}
8461-
8462-
// Rescale logits again so that new min and max logits match original min and max logits
8463-
for (size_t i = 0; i < candidates->size; ++i) {
8464-
candidates->data[i].logit = (candidates->data[i].logit - new_min_logit) / (new_max_logit - new_min_logit) * (max_logit - min_logit) + min_logit;
8465-
}
8439+
// Apply quadratic transformation using the smoothing_factor
8440+
for (size_t i = 0; i < candidates->size; ++i) {
8441+
float logit_shifted = candidates->data[i].logit - h;
8442+
candidates->data[i].logit = -smoothing_factor * logit_shifted * logit_shifted + k;
8443+
}
84668444

84678445
// Verbose print top and bottom 3 logits after smoothing
84688446
printf("\nTop 3 logits after smoothing:\n");

0 commit comments

Comments
 (0)