From 0865a5277fb84f2a7150ce003dad962465eee8cc Mon Sep 17 00:00:00 2001 From: Trent McConaghy Date: Fri, 5 Jul 2024 18:18:44 +0200 Subject: [PATCH] Fix #1351: Merge small improvements from PR #1281 (PR #1352) --- pdr_backend/aimodel/aimodel_factory.py | 2 ++ pdr_backend/aimodel/aimodel_plotdata.py | 7 ++++--- pdr_backend/aimodel/aimodel_plotter.py | 2 ++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pdr_backend/aimodel/aimodel_factory.py b/pdr_backend/aimodel/aimodel_factory.py index 556f568c6..883eb935f 100644 --- a/pdr_backend/aimodel/aimodel_factory.py +++ b/pdr_backend/aimodel/aimodel_factory.py @@ -80,6 +80,7 @@ def _build_wrapped_regr( ss = self.ss assert ss.do_regr assert ycont is not None + assert X.shape[0] == ycont.shape[0], (X.shape[0], ycont.shape[0]) do_constant = min(ycont) == max(ycont) or ss.approach == "RegrConstant" # weight newest sample 10x, and 2nd-newest sample 5x @@ -145,6 +146,7 @@ def _build_direct_classif( ) -> Aimodel: ss = self.ss assert not ss.do_regr + assert X.shape[0] == len(ytrue), (X.shape[0], len(ytrue)) n_True, n_False = sum(ytrue), sum(np.invert(ytrue)) smallest_n = min(n_True, n_False) do_constant = (smallest_n == 0) or ss.approach == "ClassifConstant" diff --git a/pdr_backend/aimodel/aimodel_plotdata.py b/pdr_backend/aimodel/aimodel_plotdata.py index ad111307d..112db74d6 100644 --- a/pdr_backend/aimodel/aimodel_plotdata.py +++ b/pdr_backend/aimodel/aimodel_plotdata.py @@ -20,8 +20,8 @@ def __init__( model: Aimodel, X_train: np.ndarray, ytrue_train: np.ndarray, - ycont_train: np.ndarray, - y_thr: float, + ycont_train: Optional[np.ndarray], + y_thr: Optional[float], colnames: List[str], slicing_x: np.ndarray, sweep_vars: Optional[List[int]] = None, @@ -45,7 +45,8 @@ def __init__( assert len(colnames) == n, (len(colnames), n) assert slicing_x.shape[0] == n, (slicing_x.shape[0], n) assert ytrue_train.shape[0] == N, (ytrue_train.shape[0], N) - assert ycont_train.shape[0] == N, (ycont_train.shape[0], N) + if ycont_train is not None: + assert ycont_train.shape[0] == N, (ycont_train.shape[0], N) assert sweep_vars is None or len(sweep_vars) in [1, 2] # set values diff --git a/pdr_backend/aimodel/aimodel_plotter.py b/pdr_backend/aimodel/aimodel_plotter.py index bbf043df1..6e980ddba 100644 --- a/pdr_backend/aimodel/aimodel_plotter.py +++ b/pdr_backend/aimodel/aimodel_plotter.py @@ -138,6 +138,8 @@ def _plot_lineplot_1var(aimodel_plotdata: AimodelPlotdata): # line plot: regressor response, training data if d.model.do_regr: assert mesh_ycont_hat is not None + assert y_thr is not None + assert ycont is not None fig.add_trace( go.Scatter( x=mesh_chosen_x,