diff --git a/scar/main/_scar.py b/scar/main/_scar.py index 3c17772..f6b63bb 100644 --- a/scar/main/_scar.py +++ b/scar/main/_scar.py @@ -450,6 +450,7 @@ def train( training_set, batch_size=batch_size, shuffle=shuffle, drop_last=True ) + self.dataset = training_set # val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=test_ids) # val_generator = torch.utils.data.DataLoader( # val_set, batch_size=batch_size, shuffle=shuffle @@ -593,7 +594,7 @@ def inference( native_frequencies, and noise_ratio. \ A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type. """ - total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, cache_capacity=self.cache_capacity) + # total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, cache_capacity=self.cache_capacity) n_features = self.n_features sample_size = self.raw_count.shape[0] self.native_counts = np.empty([sample_size, n_features]) @@ -605,7 +606,7 @@ def inference( batch_size = sample_size i = 0 generator_full_data = torch.utils.data.DataLoader( - total_set, batch_size=batch_size, shuffle=False + self.dataset, batch_size=batch_size, shuffle=False ) for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data: