From 969353e836d12bd5c813cc37d32c030fd198b86c Mon Sep 17 00:00:00 2001 From: kausmees Date: Wed, 23 Jun 2021 12:03:26 +0200 Subject: [PATCH] Minor fixes to remove bug when running with no valid set. --- run_gcae.py | 5 ++++- utils/data_handler.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/run_gcae.py b/run_gcae.py index 652cdee..b6ae18d 100644 --- a/run_gcae.py +++ b/run_gcae.py @@ -670,7 +670,10 @@ def loss_func(y_pred, y_true): print("Model layers and dimensions:") print("-----------------------------") - output_valid_batch, encoded_data_valid_batch = autoencoder(input_valid[0:2], is_training = False, verbose = True) + input_test, targets_test, _ = dg.get_train_set(0.0) + if not missing_mask_input: + input_test = input_test[:,:,0, np.newaxis] + output_test, encoded_data_test = autoencoder(input_test[0:2], is_training = False, verbose = True) ######### Create objects for tensorboard summary ############################### diff --git a/utils/data_handler.py b/utils/data_handler.py index b1690bf..32b2850 100644 --- a/utils/data_handler.py +++ b/utils/data_handler.py @@ -78,7 +78,7 @@ def _define_samples(self): self.n_train_samples_orig = len(self.sample_idx_all) self.n_train_samples = self.n_train_samples_orig self.ind_pop_list_train_orig = ind_pop_list[self.sample_idx_all] - self.train_set_indices = np.array(range(self.n_train_samples)) + self.train_set_indices = np.arange(self.n_train_samples) self.n_valid_samples = 0 @@ -170,7 +170,10 @@ def define_validation_set(self, validation_split): _, _, self.sample_idx_train, self.sample_idx_valid = get_test_samples_stratified(self.genotypes_train_orig, self.ind_pop_list_train_orig, validation_split) - self.train_set_indices = np.array(range((len(self.sample_idx_train)))) + self.sample_idx_train = np.array(self.sample_idx_train) + self.sample_idx_valid = np.array(self.sample_idx_valid) + + self.train_set_indices = np.array(range(len(self.sample_idx_train))) self.n_valid_samples = len(self.sample_idx_valid) self.n_train_samples = len(self.sample_idx_train) @@ -184,8 +187,14 @@ def get_valid_set(self, sparsify): target_data_valid (n_valid_samples x n_markers): original validation genotypes ind_pop_list valid (n_valid_samples x 2) : individual and population IDs of validation samples + or + empty arrays if no valid set defined + ''' + if self.n_valid_samples == 0: + return np.array([]), np.array([]), np.array([]) + # n_valid_samples x n_markers x 2 input_data_valid = np.full((len(self.sample_idx_valid),self.genotypes_train_orig.shape[1],2),1.0)