From 4382e82c1c55bd2fc5a7254446f64529b1c59e3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sat, 30 Nov 2024 15:54:21 +0100 Subject: [PATCH] fix issues with saving results --- src/scportrait/pipeline/classification.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/scportrait/pipeline/classification.py b/src/scportrait/pipeline/classification.py index d1012b7d..0a6743af 100644 --- a/src/scportrait/pipeline/classification.py +++ b/src/scportrait/pipeline/classification.py @@ -397,25 +397,29 @@ def inference(self, features_path = tempmmap.create_empty_mmap(shape_features, dtype = np.float32) cell_ids_path = tempmmap.create_empty_mmap(shape_labels, dtype = np.int64) + labels_path = tempmmap.create_empty_mmap(shape_labels, dtype = np.int64) features = tempmmap.mmap_array_from_path(features_path) cell_ids = tempmmap.mmap_array_from_path(cell_ids_path) + labels = tempmmap.mmap_array_from_path(labels_path) #save the results for each batch into the memory mapped array at the specified indices features[ix:(ix+batch_size)] = result.numpy() cell_ids[ix:(ix+batch_size)] = class_id.unsqueeze(1) + labels[ix:(ix+batch_size)] = label.unsqueeze(1) ix += batch_size for i in range(len(dataloader) - 1): if i % 10 == 0: self.log(f"processing batch {i}") - x, label, id = next(data_iter) + x, label, class_id = next(data_iter) r = model_fun(x.to(self.config["inference_device"])) #save the results for each batch into the memory mapped array at the specified indices features[ix:(ix+r.shape[0])] = r.cpu().detach().numpy() - cell_ids[ix:(ix+r.shape[0])] = label.unsqueeze(1) + cell_ids[ix:(ix+r.shape[0])] = class_id.unsqueeze(1) + labels[ix:(ix+r.shape[0])] = label.unsqueeze(1) ix += r.shape[0] @@ -424,14 +428,11 @@ def inference(self, sigma = 1e-9 features = np.log(features + sigma) - label = label.numpy() - class_id = class_id.numpy() - # save inferred activations / predictions result_labels = [f"result_{i}" for i in range(features.shape[1])] dataframe = pd.DataFrame(data=features, columns=result_labels) - dataframe["label"] = label + dataframe["label"] = labels dataframe["cell_id"] = cell_ids.astype("int") self.log("finished processing")