Skip to content

Commit

Permalink
refactor(scar): refactor scar to allow efficient usage of GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
CaibinSh committed Jul 29, 2024
1 parent 6b101f1 commit c846d9b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions scar/main/_scar.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def train(
training_set, batch_size=batch_size, shuffle=shuffle,
drop_last=True
)
self.dataset = training_set
# val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, list_ids=test_ids)
# val_generator = torch.utils.data.DataLoader(
# val_set, batch_size=batch_size, shuffle=shuffle
Expand Down Expand Up @@ -593,7 +594,7 @@ def inference(
native_frequencies, and noise_ratio. \
A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type.
"""
total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, cache_capacity=self.cache_capacity)
# total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, cache_capacity=self.cache_capacity)
n_features = self.n_features
sample_size = self.raw_count.shape[0]
self.native_counts = np.empty([sample_size, n_features])
Expand All @@ -605,7 +606,7 @@ def inference(
batch_size = sample_size
i = 0
generator_full_data = torch.utils.data.DataLoader(
total_set, batch_size=batch_size, shuffle=False
self.dataset, batch_size=batch_size, shuffle=False
)

for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data:
Expand Down

0 comments on commit c846d9b

Please sign in to comment.