Commit

functional code to featurize cropped_rois
sophiamaedler committed Feb 16, 2025
1 parent 4bc7303 commit 613041c
Showing 2 changed files with 95 additions and 61 deletions.
155 changes: 94 additions & 61 deletions src/scportrait/pipeline/classification.py
@@ -293,6 +293,7 @@ def _load_dataset(self):
def __call__( # type: ignore
self,
extraction_dir: str,
cropped_rois: bool = False,
partial: bool = False,
dataset_type=HDF5SingleCellDataset,
):
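The new cropped_rois flag threads through to the inference calls below; within this diff it only selects which CSV files the results are written to. A minimal usage sketch, not part of the commit — the featurizer instance and path are illustrative placeholders, and in a project workflow extraction_dir is supplied automatically:

featurizer(
    "project/extraction/data.h5",  # placeholder extraction directory
    cropped_rois=True,             # route output to the cropped_rois_* CSV files
    partial=False,                 # featurize all cells rather than a subsample
)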
@@ -393,14 +394,24 @@ class based on the previous single-cell extraction. Therefore, only the second a

for encoder in self.encoders:
if encoder == "forward":
self.inference(dataloader, model.network.forward, partial=partial)
self.inference(
dataloader,
model.network.forward,
partial=partial,
cropped_rois=cropped_rois,
)
if encoder == "encoder":
self.inference(dataloader, model.network.encoder, partial=partial)
self.inference(
dataloader,
model.network.encoder,
partial=partial,
cropped_rois=cropped_rois,
)

# ensure all intermediate results are cleared after processing
self.clear_temp_dir()

def inference(self, dataloader, model_fun, partial=False):
def inference(self, dataloader, model_fun, partial=False, cropped_rois=False):
# 1. performs inference for a dataloader and a given network call
# 2. saves the results to file

@@ -465,14 +476,28 @@ def inference(self, dataloader, model_fun, partial=False):

self.log("finished processing")

if partial:
path = os.path.join(
self.run_path, f"partial_dimension_reduction_{model_fun.__name__}.csv"
)
if cropped_rois:
if partial:
path = os.path.join(
self.run_path,
f"cropped_rois_partial_dimension_reduction_{model_fun.__name__}.csv",
)
else:
path = os.path.join(
self.run_path,
f"cropped_rois_dimension_reduction_{model_fun.__name__}.csv",
)

else:
path = os.path.join(
self.run_path, f"dimension_reduction_{model_fun.__name__}.csv"
)
if partial:
path = os.path.join(
self.run_path,
f"partial_dimension_reduction_{model_fun.__name__}.csv",
)
else:
path = os.path.join(
self.run_path, f"dimension_reduction_{model_fun.__name__}.csv"
)

dataframe.to_csv(path)
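All four branches above compose the same pattern: <cropped_rois_><partial_><base name>.csv. A self-contained sketch of an equivalent name builder — the helper is hypothetical, not part of the commit; the same prefix scheme recurs in the ensemble, featurizer, and ConvNeXt writers below:

import os

def _result_path(run_path, base, cropped_rois=False, partial=False):
    # compose optional prefixes instead of nesting if/else four ways
    prefix = ("cropped_rois_" if cropped_rois else "") + ("partial_" if partial else "")
    return os.path.join(run_path, prefix + base)

assert _result_path(
    "run", "dimension_reduction_forward.csv", cropped_rois=True, partial=True
).endswith("cropped_rois_partial_dimension_reduction_forward.csv")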

@@ -561,7 +586,7 @@ def get_gpu_memory_usage(self):
print("Error:", e)
return None

def inference(self, dataloader, model_ensemble, partial=False):
def inference(self, dataloader, model_ensemble, partial=False, cropped_rois=False):
data_iter = iter(dataloader)
self.log(
f"Start processing {len(data_iter)} batches with {len(model_ensemble)} models from ensemble."
@@ -620,19 +645,32 @@ def inference(self, dataloader, model_ensemble, partial=False):
new_order = columns_to_move + other_columns
dataframe = dataframe[new_order]

if partial:
path = os.path.join(
self.directory, f"partial_ensemble_inference_{self.ensemble_name}.csv"
)
if cropped_rois:
if partial:
path = os.path.join(
self.directory,
f"cropped_rois_partial_ensemble_inference_{self.ensemble_name}.csv",
)
else:
path = os.path.join(
self.directory,
f"cropped_rois_ensemble_inference_{self.ensemble_name}.csv",
)
else:
path = os.path.join(
self.directory, f"ensemble_inference_{self.ensemble_name}.csv"
)
if partial:
path = os.path.join(
self.directory,
f"partial_ensemble_inference_{self.ensemble_name}.csv",
)
else:
path = os.path.join(
self.directory, f"ensemble_inference_{self.ensemble_name}.csv"
)
dataframe.to_csv(path, sep=",")

self.log(f"Results saved to file: {path}")

def __call__(self, extraction_dir, partial=False):
def __call__(self, extraction_dir, partial=False, cropped_rois=False):
"""
Perform classification on the provided HDF5 dataset.
@@ -718,7 +756,10 @@ class based on the previous single-cell extraction. Therefore, no parameters nee

# perform inference
self.inference(
dataloader=dataloader, model_ensemble=model_ensemble, partial=partial
dataloader=dataloader,
model_ensemble=model_ensemble,
partial=partial,
cropped_rois=cropped_rois,
)


@@ -844,11 +885,10 @@ def is_Int(self, s):
def __call__(
self,
extraction_dir,
accessory,
size=0,
partial=False,
cropped_rois=False,
project_dataloader=HDF5SingleCellDataset,
accessory_dataloader=HDF5SingleCellDataset,
):
"""
Perform featurization on the provided HDF5 dataset.
@@ -857,14 +897,10 @@ def __call__(
----------
extraction_dir : str
Directory containing the extracted HDF5 files from the project. If this class is used as part of a project processing workflow, this argument will be provided automatically.
accessory : list
List containing accessory datasets on which inference should be performed in addition to the cells contained within the current project.
size : int, optional, default=0
How many cells should be selected for inference. Default is 0, meaning all cells are selected.
project_dataloader : HDF5SingleCellDataset, optional
Dataloader for the project dataset. Default is HDF5SingleCellDataset.
accessory_dataloader : HDF5SingleCellDataset, optional
Dataloader for the accessory datasets. Default is HDF5SingleCellDataset.
Returns
-------
@@ -879,11 +915,7 @@ def __call__(
--------
.. code-block:: python
# Define accessory dataset: additional HDF5 datasets that you want to perform an inference on
# Leave empty if you only want to infer on all extracted cells in the current project
accessory = ([], [], [])
project.classify(accessory=accessory)
project.classify()
Notes
-----
@@ -912,10 +944,6 @@ def __call__(
self.log(f"starting with run {self.current_run}")
self.log(self.config)

accessory_sizes, accessory_labels, accessory_paths = accessory

self.log(f"{len(accessory_sizes)} different accessory datasets specified")

# Generate project dataset dataloader
t = transforms.Compose([])

@@ -933,22 +961,6 @@ def __call__(
residual = len(dataset) - size
dataset, _ = torch.utils.data.random_split(dataset, [size, residual])
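If size is non-zero, the dataset is randomly subsampled before inference: random_split keeps a random size-cell subset and discards the remainder. A toy, self-contained illustration, not part of the commit:

import torch

toy = torch.utils.data.TensorDataset(torch.arange(10).unsqueeze(1))
subset, _ = torch.utils.data.random_split(toy, [3, len(toy) - 3])
assert len(subset) == 3  # a random 3-cell subset; the remaining 7 cells are dropped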

# Load accessory dataset
for i in range(len(accessory_sizes)):
self.log(f"loading {accessory_paths[i]}")
with redirect_stdout(f):
local_dataset = HDF5SingleCellDataset(
[accessory_paths[i]], [i + 1], transform=t, return_fake_id=True
)

if len(local_dataset) > accessory_sizes[i]:
residual = len(local_dataset) - accessory_sizes[i]
local_dataset, _ = torch.utils.data.random_split(
local_dataset, [accessory_sizes[i], residual]
)

dataset = torch.utils.data.ConcatDataset([dataset, local_dataset])

# Log stdout
out = f.getvalue()
self.log(out)
@@ -960,7 +972,7 @@ def __call__(
num_workers=self.config["dataloader_worker_number"],
shuffle=False,
)
self.inference(dataloader, partial=partial)
self.inference(dataloader, partial=partial, cropped_rois=cropped_rois)

def calculate_statistics(self, img, channel=-1):
"""
@@ -1038,7 +1050,7 @@ def calculate_statistics(self, img, channel=-1):
)
return results
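For context, a toy sketch of the kind of per-channel summary a cell featurizer computes; the function name and the exact statistics below are illustrative assumptions, not the actual scPortrait implementation:

import torch

def toy_channel_statistics(img, channel=-1):
    # summarize one image channel with a handful of scalar features
    values = img[channel].flatten().float()
    return {
        "mean": values.mean().item(),
        "median": values.median().item(),
        "sum": values.sum().item(),
        "q25": values.quantile(0.25).item(),
        "q75": values.quantile(0.75).item(),
    }

stats = toy_channel_statistics(torch.rand(3, 64, 64))  # statistics of the last channel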

def inference(self, dataloader, partial=False):
def inference(self, dataloader, partial=False, cropped_rois=False):
"""
Perform inference for a dataloader and save the results to a file.
@@ -1095,10 +1107,20 @@ def inference(self, dataloader, partial=False):

self.log("finished processing")

if partial:
path = os.path.join(self.run_path, "partial_calculated_features.csv")
if cropped_rois:
if partial:
path = os.path.join(
self.run_path, "cropped_rois_partial_calculated_features.csv"
)
else:
path = os.path.join(
self.run_path, "cropped_rois_calculated_features.csv"
)
else:
path = os.path.join(self.run_path, "calculated_features.csv")
if partial:
path = os.path.join(self.run_path, "partial_calculated_features.csv")
else:
path = os.path.join(self.run_path, "calculated_features.csv")
dataframe.to_csv(path)


@@ -1278,6 +1300,7 @@ def __call__( # type: ignore
self,
extraction_dir: str,
partial: bool = False,
cropped_rois: bool = False,
dataset_type=HDF5SingleCellDataset,
):
"""
@@ -1352,9 +1375,9 @@ class based on the previous single-cell extraction. Therefore, only the second a
shuffle=False,
)

self.inference(dataloader, model, partial=partial)
self.inference(dataloader, model, partial=partial, cropped_rois=cropped_rois)

def inference(self, dataloader, model_fun, partial=False):
def inference(self, dataloader, model_fun, partial=False, cropped_rois=False):
# 1. performs inference for a dataloader and a given network call
# 2. saves the results to file

@@ -1420,10 +1443,20 @@ def inference(self, dataloader, model_fun, partial=False):

self.log("finished processing")

if partial:
path = os.path.join(self.run_path, "partial_featurization_ConvNeXt.csv")
if cropped_rois:
if partial:
path = os.path.join(
self.run_path, "cropped_roi_partial_featurization_ConvNeXt.csv"
)
else:
path = os.path.join(
self.run_path, "cropped_roi_featurization_ConvNeXt.csv"
)
else:
path = os.path.join(self.run_path, "featurization_ConvNeXt.csv")
if partial:
path = os.path.join(self.run_path, "partial_featurization_ConvNeXt.csv")
else:
path = os.path.join(self.run_path, "featurization_ConvNeXt.csv")

dataframe.to_csv(path)
self.clear_temp_dir() # ensure temp directories are deleted after processing
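Note that the ConvNeXt writer above uses a cropped_roi_ prefix (singular), unlike the cropped_rois_ prefix of the other writers in this diff. Its four possible outputs, spelled out as a mapping keyed by (cropped_rois, partial):

convnext_outputs = {
    (False, False): "featurization_ConvNeXt.csv",
    (False, True): "partial_featurization_ConvNeXt.csv",
    (True, False): "cropped_roi_featurization_ConvNeXt.csv",
    (True, True): "cropped_roi_partial_featurization_ConvNeXt.csv",
}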
1 change: 1 addition & 0 deletions src/scportrait/pipeline/project.py
@@ -854,6 +854,7 @@ def classify(self, cropped_rois=False, partial=False, *args, **kwargs):
self.classification_f(
f"{input_extraction}/{filename}",
partial=partial,
cropped_rois=cropped_rois,
*args,
**kwargs,
)
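With this change, the flag set on Project.classify() is forwarded to the configured featurization class. A usage sketch, where project is assumed to be an already-configured scPortrait Project instance:

project.classify(cropped_rois=True)   # results land in the cropped_rois_* CSV files
project.classify()                    # default cropped_rois=False keeps the previous filenames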
