1
1
from typing import Dict , List , Tuple , Optional
2
2
3
+ import numpy as np
3
4
from dataprep_ml import StatisticalAnalysis
4
5
5
6
from lightwood .helpers .log import log
8
9
from lightwood .analysis .base import BaseAnalysisBlock
9
10
from lightwood .data .encoded_ds import EncodedDs
10
11
from lightwood .encoder .text .pretrained import PretrainedLangEncoder
11
- from lightwood .api .types import ModelAnalysis , TimeseriesSettings , PredictionArguments
12
+ from lightwood .api .types import ModelAnalysis , ProblemDefinition , PredictionArguments
12
13
13
14
14
15
def model_analyzer (
@@ -17,7 +18,7 @@ def model_analyzer(
17
18
train_data : EncodedDs ,
18
19
stats_info : StatisticalAnalysis ,
19
20
target : str ,
20
- tss : TimeseriesSettings ,
21
+ pdef : ProblemDefinition ,
21
22
dtype_dict : Dict [str , str ],
22
23
accuracy_functions ,
23
24
ts_analysis : Dict ,
@@ -39,54 +40,64 @@ def model_analyzer(
39
40
40
41
runtime_analyzer = {}
41
42
data_type = dtype_dict [target ]
43
+ tss = pdef .timeseries_settings
42
44
43
45
# retrieve encoded data representations
44
46
encoded_train_data = train_data
45
47
encoded_val_data = data
46
48
data = encoded_val_data .data_frame
47
49
input_cols = list ([col for col in data .columns if col != target ])
48
50
49
- # predictive task
50
- is_numerical = data_type in (dtype .integer , dtype .float , dtype .num_tsarray , dtype .quantity )
51
- is_classification = data_type in (dtype .categorical , dtype .binary , dtype .cat_tsarray )
52
- is_multi_ts = tss .is_timeseries and tss .horizon > 1
53
- has_pretrained_text_enc = any ([isinstance (enc , PretrainedLangEncoder )
54
- for enc in encoded_train_data .encoders .values ()])
55
-
56
- # raw predictions for validation dataset
57
- args = {} if not is_classification else {"predict_proba" : True }
58
- filtered_df = encoded_val_data .data_frame
59
- normal_predictions = None
60
-
61
- if len (analysis_blocks ) > 0 :
62
- normal_predictions = predictor (encoded_val_data , args = PredictionArguments .from_dict (args ))
63
- normal_predictions = normal_predictions .set_index (encoded_val_data .data_frame .index )
64
-
65
- # ------------------------- #
66
- # Run analysis blocks, both core and user-defined
67
- # ------------------------- #
68
- kwargs = {
69
- 'predictor' : predictor ,
70
- 'target' : target ,
71
- 'input_cols' : input_cols ,
72
- 'dtype_dict' : dtype_dict ,
73
- 'normal_predictions' : normal_predictions ,
74
- 'data' : filtered_df ,
75
- 'train_data' : train_data ,
76
- 'encoded_val_data' : encoded_val_data ,
77
- 'is_classification' : is_classification ,
78
- 'is_numerical' : is_numerical ,
79
- 'is_multi_ts' : is_multi_ts ,
80
- 'stats_info' : stats_info ,
81
- 'tss' : tss ,
82
- 'ts_analysis' : ts_analysis ,
83
- 'accuracy_functions' : accuracy_functions ,
84
- 'has_pretrained_text_enc' : has_pretrained_text_enc
85
- }
86
-
87
- for block in analysis_blocks :
88
- log .info ("The block %s is now running its analyze() method" , block .__class__ .__name__ )
89
- runtime_analyzer = block .analyze (runtime_analyzer , ** kwargs )
51
+ if not pdef .embedding_only :
52
+ # predictive task
53
+ is_numerical = data_type in (dtype .integer , dtype .float , dtype .num_tsarray , dtype .quantity )
54
+ is_classification = data_type in (dtype .categorical , dtype .binary , dtype .cat_tsarray )
55
+ is_multi_ts = tss .is_timeseries and tss .horizon > 1
56
+ has_pretrained_text_enc = any ([isinstance (enc , PretrainedLangEncoder )
57
+ for enc in encoded_train_data .encoders .values ()])
58
+
59
+ # raw predictions for validation dataset
60
+ args = {} if not is_classification else {"predict_proba" : True }
61
+ normal_predictions = None
62
+
63
+ if len (analysis_blocks ) > 0 :
64
+ if tss .is_timeseries :
65
+ # we retrieve the first entry per group (closest to supervision cutoff)
66
+ if tss .group_by :
67
+ encoded_val_data .data_frame ['__mdb_val_idx' ] = np .arange (len (encoded_val_data ))
68
+ idxs = encoded_val_data .data_frame .groupby (by = tss .group_by ).first ()['__mdb_val_idx' ].values
69
+ encoded_val_data .data_frame = encoded_val_data .data_frame .iloc [idxs , :]
70
+ if encoded_val_data .cache_built :
71
+ encoded_val_data .X_cache = encoded_val_data .X_cache [idxs , :]
72
+ encoded_val_data .Y_cache = encoded_val_data .Y_cache [idxs , :]
73
+ normal_predictions = predictor (encoded_val_data , args = PredictionArguments .from_dict (args ))
74
+ normal_predictions = normal_predictions .set_index (encoded_val_data .data_frame .index )
75
+
76
+ # ------------------------- #
77
+ # Run analysis blocks, both core and user-defined
78
+ # ------------------------- #
79
+ kwargs = {
80
+ 'predictor' : predictor ,
81
+ 'target' : target ,
82
+ 'input_cols' : input_cols ,
83
+ 'dtype_dict' : dtype_dict ,
84
+ 'normal_predictions' : normal_predictions ,
85
+ 'data' : encoded_val_data .data_frame ,
86
+ 'train_data' : train_data ,
87
+ 'encoded_val_data' : encoded_val_data ,
88
+ 'is_classification' : is_classification ,
89
+ 'is_numerical' : is_numerical ,
90
+ 'is_multi_ts' : is_multi_ts ,
91
+ 'stats_info' : stats_info ,
92
+ 'tss' : tss ,
93
+ 'ts_analysis' : ts_analysis ,
94
+ 'accuracy_functions' : accuracy_functions ,
95
+ 'has_pretrained_text_enc' : has_pretrained_text_enc
96
+ }
97
+
98
+ for block in analysis_blocks :
99
+ log .info ("The block %s is now running its analyze() method" , block .__class__ .__name__ )
100
+ runtime_analyzer = block .analyze (runtime_analyzer , ** kwargs )
90
101
91
102
# ------------------------- #
92
103
# Populate ModelAnalysis object
0 commit comments