From 60a6818b762b1f229fc00b994f9dd1b5a40f3255 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Fri, 29 Nov 2024 18:47:07 +0100 Subject: [PATCH] fix bug --- src/scportrait/pipeline/classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scportrait/pipeline/classification.py b/src/scportrait/pipeline/classification.py index c7854c5f..14cb094f 100644 --- a/src/scportrait/pipeline/classification.py +++ b/src/scportrait/pipeline/classification.py @@ -403,7 +403,7 @@ def inference(self, #save the results for each batch into the memory mapped array at the specified indices features[ix:(ix+batch_size)] = result.numpy() - cell_ids[ix:(ix+batch_size)] = class_id.numpy().unsqueeze(1) + cell_ids[ix:(ix+batch_size)] = class_id.unsqueeze(1) for i in range(len(dataloader) - 1): if i % 10 == 0: @@ -415,7 +415,7 @@ def inference(self, #save the results for each batch into the memory mapped array at the specified indices features[ix:(ix+batch_size)] = r.cpu().detach().numpy() - cell_ids[ix:(ix+batch_size)] = label.numpy().unsqueeze(1) + cell_ids[ix:(ix+batch_size)] = label.unsqueeze(1) if hasattr(self.config, "log_transform"): if self.config["log_transform"]: