@@ -851,7 +851,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
851
851
852
852
LOG_INF (" %s : calculating hellaswag score over selected tasks.\n " , __func__);
853
853
854
- LOG (" \n task\t acc_norm\n " );
854
+ LOG (" \n task\t acc_norm\t 95%% confidence interval \ n" );
855
855
856
856
double acc = 0 .0f ;
857
857
@@ -985,8 +985,22 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
985
985
acc += 1.0 ;
986
986
}
987
987
988
- // Print the accumulated accuracy mean x 100
989
- LOG (" %zu\t %.8lf\n " , i + 1 , acc/double (i + 1 )*100.0 );
988
+ double freq = acc / double (i + 1 );
989
+
990
+ const double za = 1.95996398454 ;
991
+
992
+ // // Wald normal approx
993
+ // double conf =za*sqrt(freq*(1-freq)/double(i + 1));
994
+ // LOG("%zu\t%.8lf +/- %.8lf\n", i + 1, freq*100.0, conf*100.0);
995
+
996
+ // Wilson score interval, more accurate
997
+ double z = za * za / double (i + 1 );
998
+ double cnf = z * sqrt (double (i + 1 ) * (4.0 * freq * (1 - freq) + z)) / (za + za);
999
+ double a = (freq + z * 0.5 - cnf) / (1.0 + z);
1000
+ double b = (freq + z * 0.5 + cnf) / (1.0 + z);
1001
+
1002
+ // Print the accumulated accuracy mean x 100 and confidence interval
1003
+ LOG (" %zu\t %3.8lf%%\t [%3.4lf%%, %3.4lf%%]\n " , i + 1 , freq * 100.0 , a * 100.0 , b * 100.0 );
990
1004
}
991
1005
992
1006
i0 = i1 - 1 ;
0 commit comments