Skip to content

Commit 4ccea21

Browse files
authored
hellaswag: display estimated score confidence interval (ggml-org#12797)
1 parent 1a1ab7e commit 4ccea21

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

examples/perplexity/perplexity.cpp

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -851,7 +851,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
851 851

852 852
LOG_INF("%s : calculating hellaswag score over selected tasks.\n", __func__);
853 853

854-
LOG("\ntask\tacc_norm\n");
854+
LOG("\ntask\tacc_norm\t95%% confidence interval\n");
855 855

856 856
double acc = 0.0f;
857 857

@@ -985,8 +985,22 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
985 985
acc += 1.0;
986 986
}
987 987

988-
// Print the accumulated accuracy mean x 100
989-
LOG("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
988+
double freq = acc / double(i + 1);
989+
990+
const double za = 1.95996398454;
991+
992+
// // Wald normal approx
993+
// double conf =za*sqrt(freq*(1-freq)/double(i + 1));
994+
// LOG("%zu\t%.8lf +/- %.8lf\n", i + 1, freq*100.0, conf*100.0);
995+
996+
// Wilson score interval, more accurate
997+
double z = za * za / double(i + 1);
998+
double cnf = z * sqrt(double(i + 1) * (4.0 * freq * (1 - freq) + z)) / (za + za);
999+
double a = (freq + z * 0.5 - cnf) / (1.0 + z);
1000+
double b = (freq + z * 0.5 + cnf) / (1.0 + z);
1001+
1002+
// Print the accumulated accuracy mean x 100 and confidence interval
1003+
LOG("%zu\t%3.8lf%%\t[%3.4lf%%, %3.4lf%%]\n", i + 1, freq * 100.0, a * 100.0, b * 100.0);
990 1004
}
991 1005

992 1006
i0 = i1 - 1;

0 commit comments

Comments
 (0)