-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprediction_engine.py
36 lines (27 loc) · 1.11 KB
/
prediction_engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import sqrt_distance_classifier
import abs_distance_classifier
import percent_distance_classifier
import stddev_classifier
import knn
# Returns a copy of a dataframe point in which every value found in int_map is replaced by its mapped value.
def convert_point(point, int_map):
    """Return a new list with the values of *point* translated through *int_map*.

    Each value that is a key of *int_map* is replaced by ``int_map[value]``;
    every other value is copied through unchanged. *point* itself is not
    modified.

    Args:
        point: Iterable of values to convert.
        int_map: Mapping from original values to their replacements.

    Returns:
        A list with one entry per value of *point*, in the same order.
    """
    converted = []
    for value in point:
        try:
            is_mapped = value in int_map
        except TypeError:
            # Unhashable values (e.g. a list) can never be mapping keys, so
            # treat them as "not mapped" instead of aborting the conversion.
            is_mapped = False
        converted.append(int_map[value] if is_mapped else value)
    return converted
# Predicts the class of a point using the classifier named by model.algorithm.
def predict(point, model):
    """Classify *point* with the classifier selected by ``model.algorithm``.

    The point is first translated with ``convert_point`` using
    ``model.int_map`` and then handed to the ``classify`` function of the
    matching classifier module.

    Args:
        point: Iterable of raw values describing the point to classify.
        model: Object providing ``int_map`` and ``algorithm``, plus the data
            the selected classifier needs (``mean_map``, ``stddev_map`` or
            ``df_sampled``).

    Returns:
        Whatever the selected classifier's ``classify`` returns, or ``None``
        when ``model.algorithm`` matches none of the known names.
    """
    features = convert_point(point, model.int_map)
    algorithm = model.algorithm

    # These three classifiers share one call shape: (features, mean_map).
    mean_map_classifiers = (
        ('sqrt_distance_classifier', sqrt_distance_classifier),
        ('abs_distance_classifier', abs_distance_classifier),
        ('percent_distance_classifier', percent_distance_classifier),
    )
    for name, classifier in mean_map_classifiers:
        if algorithm == name:
            return classifier.classify(features, model.mean_map)

    # NOTE(review): the name matched here is 'stddev_distance_classifier'
    # while the module is called stddev_classifier - confirm this is the
    # string that model builders actually store.
    if algorithm == 'stddev_distance_classifier':
        return stddev_classifier.classify(
            features, model.stddev_map, model.mean_map
        )
    if algorithm == 'knn':
        return knn.classify(features, model.df_sampled)

    # Unknown algorithm names yield None.
    return None