-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_utils.py
59 lines (55 loc) · 1.88 KB
/
plot_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from tabulate import tabulate # type: ignore
def print_table(results, headers, title="", format=False):
    """Print ``results`` as a grid table, sorted ascending by the last column.

    Each row is expected to look like ``[label, value1, value2, ...]``: the
    first cell is a label (e.g. a model name) and, when ``format`` is True,
    every remaining cell must be numeric. Floats are rendered with two
    decimals via ``tabulate``'s ``floatfmt``.

    Args:
        results: Iterable of row sequences.
        headers: Column headers, passed straight through to ``tabulate``.
        title: Optional heading printed (surrounded by blank lines) before
            the table. Shown with and without ``format``.
        format: When True, in every numeric column the maximum value(s) get
            a bright-yellow ANSI background and the second-largest distinct
            value(s) are underlined.
    """
    HIGHLIGHT = "\033[103m"  # ANSI SGR 103: bright yellow background
    UNDERLINE = "\033[4m"  # ANSI SGR 4: underline
    END = "\033[0m"  # ANSI SGR 0: reset all attributes
    TOLERANCE = 1e-10  # absolute tolerance for float equality

    # Sort the raw rows once, before any ANSI decoration is added, so the
    # ordering never depends on parsing decorated strings back to numbers.
    rows = sorted(results, key=lambda row: row[-1])

    if format:
        # Precompute (max, runner-up) once per numeric column instead of
        # re-scanning and re-sorting the whole column for every cell.
        extremes = []
        for column in zip(*(row[1:] for row in rows)):  # label column excluded
            top = max(column)
            # The runner-up is the largest value that is NOT tied with the
            # max, so a tie at the top still leaves a second-best to mark.
            below_top = [v for v in column if abs(v - top) >= TOLERANCE]
            second = max(below_top) if below_top else top
            extremes.append((top, second))

        decorated = []
        for row in rows:
            new_row = [row[0]]  # keep the label as-is
            for i, value in enumerate(row[1:]):
                top, second = extremes[i]
                if abs(value - top) < TOLERANCE:
                    new_row.append(f"{HIGHLIGHT}{value}{END}")
                elif abs(value - second) < TOLERANCE:
                    new_row.append(f"{UNDERLINE}{value}{END}")
                else:
                    new_row.append(f"{value}")
            decorated.append(new_row)
        rows = decorated

    # Show table title if provided (for both plain and formatted output).
    if title:
        print(f"\n{title}\n")
    # tabulate strips ANSI codes when detecting numeric columns, so floatfmt
    # still applies to the decorated cells.
    print(tabulate(rows, headers=headers, tablefmt="grid", floatfmt=".2f"))