Skip to content

Commit

Permalink
Minor fixes to remove bug when running with no valid set.
Browse files Browse the repository at this point in the history
  • Loading branch information
kausmees committed Jun 23, 2021
1 parent 5d5fe12 commit 969353e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
5 changes: 4 additions & 1 deletion run_gcae.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,10 @@ def loss_func(y_pred, y_true):
print("Model layers and dimensions:")
print("-----------------------------")

output_valid_batch, encoded_data_valid_batch = autoencoder(input_valid[0:2], is_training = False, verbose = True)
input_test, targets_test, _ = dg.get_train_set(0.0)
if not missing_mask_input:
input_test = input_test[:,:,0, np.newaxis]
output_test, encoded_data_test = autoencoder(input_test[0:2], is_training = False, verbose = True)

######### Create objects for tensorboard summary ###############################

Expand Down
13 changes: 11 additions & 2 deletions utils/data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _define_samples(self):
self.n_train_samples_orig = len(self.sample_idx_all)
self.n_train_samples = self.n_train_samples_orig
self.ind_pop_list_train_orig = ind_pop_list[self.sample_idx_all]
self.train_set_indices = np.array(range(self.n_train_samples))
self.train_set_indices = np.arange(self.n_train_samples)

self.n_valid_samples = 0

Expand Down Expand Up @@ -170,7 +170,10 @@ def define_validation_set(self, validation_split):

_, _, self.sample_idx_train, self.sample_idx_valid = get_test_samples_stratified(self.genotypes_train_orig, self.ind_pop_list_train_orig, validation_split)

self.train_set_indices = np.array(range((len(self.sample_idx_train))))
self.sample_idx_train = np.array(self.sample_idx_train)
self.sample_idx_valid = np.array(self.sample_idx_valid)

self.train_set_indices = np.array(range(len(self.sample_idx_train)))
self.n_valid_samples = len(self.sample_idx_valid)
self.n_train_samples = len(self.sample_idx_train)

Expand All @@ -184,8 +187,14 @@ def get_valid_set(self, sparsify):
target_data_valid (n_valid_samples x n_markers): original validation genotypes
ind_pop_list valid (n_valid_samples x 2) : individual and population IDs of validation samples
or
empty arrays if no valid set defined
'''

if self.n_valid_samples == 0:
return np.array([]), np.array([]), np.array([])

# n_valid_samples x n_markers x 2
input_data_valid = np.full((len(self.sample_idx_valid),self.genotypes_train_orig.shape[1],2),1.0)

Expand Down

0 comments on commit 969353e

Please sign in to comment.