1
+ from utils .utils import get_input_from_inputs
2
+ import numpy as np
3
+ import pandas as pd
4
+
5
+
6
+
7
+ def validate_classification_object (obj , object_type ):
8
+ predicted_object_keys = ["class" , "confidence" ]
9
+ if object_type == "gt" :
10
+ assert type (obj ) == str
11
+ elif object_type == "prediction" :
12
+ for k in predicted_object_keys :
13
+ if not k in obj .keys ():
14
+ raise TypeError ("The object is missing the key " , k )
15
+ else :
16
+ raise TypeError ("The object should have a gt or prediction type" )
17
+
18
+
19
+ def validate_detection_object (obj , object_type ):
20
+ gt_object_keys = ["category" , "bounding_box" ]
21
+ detected_object_keys = ["category" , "bounding_box" , "confidence" ]
22
+ bounding_box_format = ["x" , "y" , "w" , "h" ]
23
+
24
+ if object_type == "gt" :
25
+ for k in gt_object_keys :
26
+ if not k in obj .keys ():
27
+ raise TypeError ("" )
28
+ for i in bounding_box_format :
29
+ if not i in obj .get ("bounding_box" , {}).keys ():
30
+ raise TypeError ("" )
31
+
32
+ elif object_type == "detections" :
33
+ for k in detected_object_keys :
34
+ if not k in obj .keys ():
35
+ raise TypeError ("" )
36
+ for i in bounding_box_format :
37
+ if not i in obj .get ("bounding_box" , {}).keys ():
38
+ raise TypeError ("" )
39
+
40
+ else :
41
+ raise TypeError ("" )
42
+
43
+
44
+
45
+ def compute_confusion_matrix (gt_y , pred_y , labels ):
46
+ if labels is None :
47
+ labels = set (gt_y )
48
+
49
+ assert len (gt_y ) == len (pred_y )
50
+ confusion_matrix = np .zeros (len (labels ), len (labels ))
51
+ for i in range (len (gt_y )):
52
+ if gt_y [i ] == pred_y [i ]:
53
+ confusion_matrix [gt_y [i ], gt_y [i ]] += 1
54
+ else :
55
+ confusion_matrix [gt_y [i ], pred_y [i ]] += 1
56
+ confusion_matrix [pred_y [i ], gt_y [i ]] += 1
57
+
58
+ return confusion_matrix
59
+
60
+
61
+ def classification_metrics (inputs ):
62
+
63
+ gt_y = get_input_from_inputs (
64
+ inputs ,
65
+ "ground_truths" ,
66
+ expected_input_type = "array" ,
67
+ expected_list_type = "string"
68
+ )
69
+
70
+ pred_y = get_input_from_inputs (
71
+ inputs ,
72
+ "predictions" ,
73
+ expected_input_type = "array" ,
74
+ expected_list_type = "string"
75
+ )
76
+
77
+ # Get the list of the ground truths labels
78
+ labels = set (gt_y )
79
+
80
+ # The ground truths and the predictions should have the same length
81
+ assert len (gt_y ) == len (pred_y )
82
+
83
+ # Validate the ground truths and predictions
84
+
85
+ # Compute confusion matrix
86
+ confusion_matrix = compute_confusion_matrix (gt_y , pred_y , labels )
87
+
88
+ # Compute TP, TN, FP and FN for each class
89
+ results = dict .fromkeys (labels , {"TP" : 0 , "TN" : 0 , "FP" : 0 , "FN" : 0 })
90
+ for k , _ in results .items ():
91
+ tp = confusion_matrix [k , k ]
92
+ fp = confusion_matrix [:, k ] - tp
93
+ fn = confusion_matrix [k , :] - tp
94
+ results [k ]["TP" ] = tp
95
+ results [k ]["FP" ] = fp
96
+ results [k ]["FN" ] = fn
97
+ results [k ]["TN" ] = len (gt_y ) - tp - fp - fn
98
+
99
+ # Compute Precision and Recall for each class
100
+ precision_recall_per_class = dict .fromkeys (labels , {"Precision" : .0 , "Recall" : .0 })
101
+ for k , _ in results .items ():
102
+ precision_recall_per_class [k ]["Precision" ] = results [k ]["TP" ] / (results [k ]["TP" ] + results [k ]["FP" ])
103
+ precision_recall_per_class [k ]["Recall" ] = results [k ]["TP" ] / (results [k ]["TP" ] + results [k ]["FN" ])
104
+
105
+ precision = sum ([v ["Precision" ] for _ , v in precision_recall_per_class .items ()]) / len (labels )
106
+ recall = sum ([v ["Recall" ] for _ , v in precision_recall_per_class .items ()]) / len (labels )
107
+ f1_score = 2 * (precision * recall ) / (precision + recall )
108
+ total_metrics = {"Precision" : precision , "Recall" : recall , "f1-score" : f1_score }
109
+
110
+ return precision_recall_per_class , total_metrics
111
+
112
+
113
+ def compute_iou (bbox1 , bbox2 ):
114
+
115
+ # Compute the intersection of bbox1 and bbox2
116
+ xA = max (bbox1 [0 ], bbox2 [0 ])
117
+ yA = max (bbox1 [1 ], bbox2 [1 ])
118
+ xB = min (bbox1 [0 ] + bbox1 [2 ], bbox2 [0 ] + bbox2 [2 ])
119
+ yB = min (bbox1 [1 ] + bbox1 [3 ], bbox2 [1 ] + bbox2 [3 ])
120
+ intersection = (xB - xA + 1 ) * (yB - yA + 1 )
121
+
122
+ # Compute the union of bbox1 and bbox2
123
+ area1 = bbox1 [2 ] * bbox1 [3 ]
124
+ area2 = bbox2 [2 ] * bbox2 [3 ]
125
+ union = area1 + area2 - intersection
126
+
127
+ # Compute IoU
128
+ iou = float (intersection / union )
129
+
130
+ return iou
131
+
132
+
133
+ def object_detection_metrics (inputs ):
134
+ gt = get_input_from_inputs (
135
+ inputs ,
136
+ "ground_truths" ,
137
+ expected_input_type = "array" ,
138
+ expected_list_type = "object"
139
+ )
140
+
141
+ detections = get_input_from_inputs (
142
+ inputs ,
143
+ "detections" ,
144
+ expected_input_type = "array" ,
145
+ expected_list_type = "object"
146
+ )
147
+
148
+ categories = get_input_from_inputs (
149
+ inputs ,
150
+ "categories" ,
151
+ expected_input_type = "array" ,
152
+ expected_list_type = "string"
153
+ )
154
+
155
+ iou_threshold = get_input_from_inputs (
156
+ inputs ,
157
+ "iou_threshold" ,
158
+ expected_input_type = "number"
159
+ )
160
+
161
+ images_names = list (gt .keys ())
162
+ metrics_per_category = pd .DataFrame (data = np .zeros ((len (categories ), 6 )), index = categories ,
163
+ columns = ["All_GT" , "All_Detections" , "TP" , "FP" , "Precision" , "Recall" ])
164
+ metrics_per_image = pd .DataFrame (data = np .zeros ((len (images_names ), 6 )), index = images_names ,
165
+ columns = ["All_GT" , "All_Detections" , "TP" , "FP" , "Precision" , "Recall" ])
166
+
167
+ for idx in range (len (images_names )):
168
+ image_name = images_names [idx ]
169
+ gt_objects = gt .get (image_name , list ())
170
+ metrics_per_image .at [image_name , "All_GT" ] = len (gt_objects )
171
+ detected_objects = detections .get (image_name , list ())
172
+ ordered_detected_objects = sorted (detected_objects , key = lambda d : d ['confidence' ], reverse = True )
173
+ metrics_per_image .at [image_name , "All_Detections" ] = len (detected_objects )
174
+
175
+ for i in range (len (gt_objects )): # Loop through GT
176
+ category1 = gt_objects [i ]["category" ]
177
+ bbox1 = list (gt_objects [i ]["bounding_box" ].values ())
178
+ metrics_per_category .at [category1 , "All_GT" ] += 1
179
+
180
+ for j in range (len (ordered_detected_objects )): # Loop through Detections
181
+ category2 = ordered_detected_objects [j ]["category" ]
182
+ bbox2 = list (ordered_detected_objects [j ]["bounding_box" ].values ())
183
+ metrics_per_category .at [category2 , "All_Detections" ] += 1
184
+
185
+ iou = compute_iou (bbox1 , bbox2 )
186
+
187
+ if iou >= iou_threshold :
188
+ if not gt [image_name ][i ].get ("found" , False ):
189
+ if category1 == category2 :
190
+ gt [image_name ][i ]["found" ] = True
191
+ gt [image_name ][i ]["confidence" ] = ordered_detected_objects [j ]["confidence" ]
192
+ metrics_per_category .at [category1 , "TP" ] += 1
193
+ metrics_per_image .at [image_name , "TP" ] += 1
194
+ else :
195
+ gt [image_name ][i ]["found" ] = False
196
+ metrics_per_category .at [category2 , "FP" ] += 1
197
+ metrics_per_image .at [image_name , "FP" ] += 1
198
+ else :
199
+ metrics_per_category .at [category2 , "FP" ] += 1
200
+ metrics_per_image .at [image_name , "FP" ] += 1
201
+
202
+ else :
203
+ continue
204
+
205
+ for idx , row in metrics_per_category .iterrows ():
206
+ metrics_per_category .at [idx , "Precision" ] = metrics_per_category .loc [idx , "TP" ] / (metrics_per_category .loc [idx , "TP" ] + metrics_per_category .loc [idx , "FP" ])
207
+ metrics_per_category .at [idx , "Recall" ] = metrics_per_category .loc [idx , "TP" ] / metrics_per_category .loc [idx , "All_GT" ]
208
+
209
+ for idx , row in metrics_per_image .iterrows ():
210
+ metrics_per_image .at [idx , "Precision" ] = metrics_per_image .loc [idx , "TP" ] / (metrics_per_image .loc [idx , "TP" ] + metrics_per_image .loc [idx , "FP" ])
211
+ metrics_per_image .at [idx , "Recall" ] = metrics_per_image .loc [idx , "TP" ] / metrics_per_image .loc [idx , "All_GT" ]
212
+
213
+ # Covert data frames to Dicts
214
+
215
+ return [metrics_per_category , metrics_per_image , gt ]
216
+
0 commit comments