
Commit 5d4ff64

added polars-based accuracy summary to the benchmark

1 parent fbecc51

1 file changed: +29 −7 lines changed

benchmark/dbally_benchmark/context_benchmark.py

@@ -1,14 +1,14 @@
 # pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring
+import polars as pl
+
 import dbally
 import asyncio
 import typing
 import json
 import traceback
 import os
-
 import tqdm.asyncio
 import sqlalchemy
-import pydantic
 from typing_extensions import TypeAlias
 from copy import deepcopy
 from sqlalchemy import create_engine
@@ -31,7 +31,8 @@
 Candidate = Base.classes.candidates


-class MyData(BaseCallerContext, pydantic.BaseModel):
+@dataclass
+class MyData(BaseCallerContext):
     first_name: str
     surname: str
     position: str
@@ -41,7 +42,8 @@ class MyData(BaseCallerContext, pydantic.BaseModel):
     country: str


-class OpenPosition(BaseCallerContext, pydantic.BaseModel):
+@dataclass
+class OpenPosition(BaseCallerContext):
     position: str
     min_years_of_experience: int
     graduated_from_university: str
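
For orientation, the two context classes are now plain dataclasses instead of pydantic models. Below is a minimal, self-contained sketch of that shape; the BaseCallerContext stub and the sample field values are stand-ins for illustration (dbally's real base class is not shown in this diff), and the sketch spells out the "from dataclasses import dataclass" import that the decorator requires.

from dataclasses import dataclass


class BaseCallerContext:
    """Stand-in for dbally's base context class (not the real implementation)."""


@dataclass
class MyData(BaseCallerContext):
    first_name: str
    surname: str
    position: str
    country: str  # the benchmark file defines a few more fields between position and country


# A context is now constructed like any other dataclass:
ctx = MyData(first_name="Jane", surname="Doe", position="Data Engineer", country="Poland")
print(ctx)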
@@ -130,7 +132,7 @@ def first_name_is(self, first_name: typing.Union[str, MyData]) -> sqlalchemy.Col
         return Candidate.name.startswith(first_name)


-OpenAILLMName: TypeAlias = typing.Literal['gpt-3.5-turbo', 'gpt-4-turbo', 'gpt-4o']
+OpenAILLMName: TypeAlias = typing.Literal['gpt-3.5-turbo', 'gpt-3.5-turbo-instruct', 'gpt-4-turbo', 'gpt-4o']


 def setup_collection(model_name: OpenAILLMName) -> dbally.Collection:
@@ -224,10 +226,30 @@ async def main(config: BenchmarkConfig):

             output_data[question]["answers"][llm_name].append(answer)

-    output_data_list = list(output_data.values())
+    df_out_raw = pl.DataFrame(list(output_data.values()))
+
+    df_out = (
+        df_out_raw
+        .unnest("answers")
+        .unpivot(
+            on=pl.selectors.starts_with("gpt"),
+            index=["question", "correct_answer", "context"],
+            variable_name="model",
+            value_name="answer"
+        )
+        .explode("answer")
+        .group_by(["context", "model"])
+        .agg([
+            (pl.col("correct_answer") == pl.col("answer")).mean().alias("frac_hits"),
+            (pl.col("correct_answer") == pl.col("answer")).sum().alias("n_hits"),
+        ])
+        .sort(["model", "context"])
+    )
+
+    print(df_out)

     with open(config.out_path, 'w') as file:
-        file.write(json.dumps(test_set, indent=2))
+        file.write(json.dumps(df_out_raw.to_dicts(), indent=2))


 if __name__ == "__main__":
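
To make the new summary step concrete, the following standalone sketch runs the same polars pipeline on a tiny hand-made frame shaped like list(output_data.values()): one row per question, with per-model answer lists nested under "answers". The column names mirror the diff; the questions, answers, and context values are invented purely for illustration, and unpivot is the newer name for melt in recent polars releases.

import polars as pl

# Toy stand-in for list(output_data.values()); every value here is invented.
df_out_raw = pl.DataFrame([
    {
        "question": "Who is our most senior engineer?",
        "correct_answer": "Jane Doe",
        "context": True,
        "answers": {"gpt-3.5-turbo": ["Jane Doe", "John Roe"], "gpt-4o": ["Jane Doe", "Jane Doe"]},
    },
    {
        "question": "Which candidate graduated from MIT?",
        "correct_answer": "John Roe",
        "context": False,
        "answers": {"gpt-3.5-turbo": ["John Roe", "John Roe"], "gpt-4o": ["Jane Doe", "John Roe"]},
    },
])

df_out = (
    df_out_raw
    .unnest("answers")                     # struct -> one list column per model
    .unpivot(                              # wide -> long: one row per (question, model)
        on=pl.selectors.starts_with("gpt"),
        index=["question", "correct_answer", "context"],
        variable_name="model",
        value_name="answer",
    )
    .explode("answer")                     # one row per individual benchmark run
    .group_by(["context", "model"])
    .agg([
        (pl.col("correct_answer") == pl.col("answer")).mean().alias("frac_hits"),
        (pl.col("correct_answer") == pl.col("answer")).sum().alias("n_hits"),
    ])
    .sort(["model", "context"])
)

print(df_out)  # frac_hits = accuracy, n_hits = number of correct answers

Here frac_hits is the per-model accuracy within each context setting and n_hits the raw count of matching answers, while writing df_out_raw.to_dicts() to config.out_path keeps the raw per-question answers available for later re-analysis.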
