Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Nov 29, 2024
1 parent 9bca6e0 commit 60a6818
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/scportrait/pipeline/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def inference(self,

#save the results for each batch into the memory mapped array at the specified indices
features[ix:(ix+batch_size)] = result.numpy()
cell_ids[ix:(ix+batch_size)] = class_id.numpy().unsqueeze(1)
cell_ids[ix:(ix+batch_size)] = class_id.unsqueeze(1)

for i in range(len(dataloader) - 1):
if i % 10 == 0:
Expand All @@ -415,7 +415,7 @@ def inference(self,

#save the results for each batch into the memory mapped array at the specified indices
features[ix:(ix+batch_size)] = r.cpu().detach().numpy()
cell_ids[ix:(ix+batch_size)] = label.numpy().unsqueeze(1)
cell_ids[ix:(ix+batch_size)] = label.unsqueeze(1)

if hasattr(self.config, "log_transform"):
if self.config["log_transform"]:
Expand Down

0 comments on commit 60a6818

Please sign in to comment.