1
1
# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring
2
+ import polars as pl
3
+
2
4
import dbally
3
5
import asyncio
4
6
import typing
5
7
import json
6
8
import traceback
7
9
import os
8
-
9
10
import tqdm .asyncio
10
11
import sqlalchemy
11
- import pydantic
12
12
from typing_extensions import TypeAlias
13
13
from copy import deepcopy
14
14
from sqlalchemy import create_engine
31
31
Candidate = Base .classes .candidates
32
32
33
33
34
- class MyData (BaseCallerContext , pydantic .BaseModel ):
34
+ @dataclass
35
+ class MyData (BaseCallerContext ):
35
36
first_name : str
36
37
surname : str
37
38
position : str
@@ -41,7 +42,8 @@ class MyData(BaseCallerContext, pydantic.BaseModel):
41
42
country : str
42
43
43
44
44
- class OpenPosition (BaseCallerContext , pydantic .BaseModel ):
45
+ @dataclass
46
+ class OpenPosition (BaseCallerContext ):
45
47
position : str
46
48
min_years_of_experience : int
47
49
graduated_from_university : str
@@ -130,7 +132,7 @@ def first_name_is(self, first_name: typing.Union[str, MyData]) -> sqlalchemy.Col
130
132
return Candidate .name .startswith (first_name )
131
133
132
134
133
- OpenAILLMName : TypeAlias = typing .Literal ['gpt-3.5-turbo' , 'gpt-4-turbo' , 'gpt-4o' ]
135
+ OpenAILLMName : TypeAlias = typing .Literal ['gpt-3.5-turbo' , 'gpt-3.5-turbo-instruct' , 'gpt-4-turbo' , 'gpt-4o' ]
134
136
135
137
136
138
def setup_collection (model_name : OpenAILLMName ) -> dbally .Collection :
@@ -224,10 +226,30 @@ async def main(config: BenchmarkConfig):
224
226
225
227
output_data [question ]["answers" ][llm_name ].append (answer )
226
228
227
- output_data_list = list (output_data .values ())
229
+ df_out_raw = pl .DataFrame (list (output_data .values ()))
230
+
231
+ df_out = (
232
+ df_out_raw
233
+ .unnest ("answers" )
234
+ .unpivot (
235
+ on = pl .selectors .starts_with ("gpt" ),
236
+ index = ["question" , "correct_answer" , "context" ],
237
+ variable_name = "model" ,
238
+ value_name = "answer"
239
+ )
240
+ .explode ("answer" )
241
+ .group_by (["context" , "model" ])
242
+ .agg ([
243
+ (pl .col ("correct_answer" ) == pl .col ("answer" )).mean ().alias ("frac_hits" ),
244
+ (pl .col ("correct_answer" ) == pl .col ("answer" )).sum ().alias ("n_hits" ),
245
+ ])
246
+ .sort (["model" , "context" ])
247
+ )
248
+
249
+ print (df_out )
228
250
229
251
with open (config .out_path , 'w' ) as file :
230
- file .write (json .dumps (test_set , indent = 2 ))
252
+ file .write (json .dumps (df_out_raw . to_dicts () , indent = 2 ))
231
253
232
254
233
255
if __name__ == "__main__" :
0 commit comments