diff --git a/examples/code_snippets/segmentation/run_seg_any_image.py b/examples/code_snippets/segmentation/run_seg_any_image.py index 9634acf4..dd4e9881 100644 --- a/examples/code_snippets/segmentation/run_seg_any_image.py +++ b/examples/code_snippets/segmentation/run_seg_any_image.py @@ -35,14 +35,8 @@ method = CytosolOnlySegmentationCellpose(config=config) method.config = method.config["CytosolOnlySegmentationCellpose"] -# create datastructure to save results to -method.maps = {} - # perform segmentation -method.cellpose_segmentation(image) - -# access results -seg_mask = method.maps["cytosol_segmentation"] +seg_mask = method.cellpose_segmentation(image) # plot results plt.imshow(seg_mask) diff --git a/src/scportrait/__init__.py b/src/scportrait/__init__.py index 513894ab..bf0b2d57 100644 --- a/src/scportrait/__init__.py +++ b/src/scportrait/__init__.py @@ -11,6 +11,10 @@ from scportrait import processing as pp from scportrait import tools as tl +# Python 3.12 is more strict about escape sequencing in string literals +# mahotas: https://github.com/luispedro/mahotas/issues/151 +warnings.filterwarnings("ignore", category=SyntaxWarning, message="invalid escape sequence") + # silence warning from spatialdata resulting in an older dask version see #139 warnings.filterwarnings("ignore", message="ignoring keyword argument 'read_only'") diff --git a/src/scportrait/pipeline/_utils/sdata_io.py b/src/scportrait/pipeline/_utils/sdata_io.py index 6e757251..2bf7aae1 100644 --- a/src/scportrait/pipeline/_utils/sdata_io.py +++ b/src/scportrait/pipeline/_utils/sdata_io.py @@ -242,7 +242,6 @@ def _write_segmentation_object_sdata( self, segmentation_object: Labels2DModel, segmentation_label: str, - classes: set[str] | None = None, overwrite: bool = False, ) -> None: """Write segmentation object to SpatialData. @@ -268,7 +267,6 @@ def _write_segmentation_sdata( self, segmentation: xarray.DataArray | np.ndarray, segmentation_label: str, - classes: set[str] | None = None, chunks: ChunkSize2D = (1000, 1000), overwrite: bool = False, ) -> None: @@ -292,7 +290,7 @@ def _write_segmentation_sdata( if not get_chunk_size(mask) == chunks: mask.data = mask.data.rechunk(chunks) - self._write_segmentation_object_sdata(mask, segmentation_label, classes=classes, overwrite=overwrite) + self._write_segmentation_object_sdata(mask, segmentation_label, overwrite=overwrite) def _write_points_object_sdata(self, points: PointsModel, points_name: str, overwrite: bool = False) -> None: """Write points object to SpatialData. diff --git a/src/scportrait/pipeline/_utils/segmentation.py b/src/scportrait/pipeline/_utils/segmentation.py index 7ec1bc7c..f626d908 100644 --- a/src/scportrait/pipeline/_utils/segmentation.py +++ b/src/scportrait/pipeline/_utils/segmentation.py @@ -671,3 +671,21 @@ def sc_all(array: NDArray) -> bool: if not x: return False return True + + +def remap_mask(input_mask: np.ndarray) -> np.ndarray: + # Create lookup table as an array + max_label = np.max(input_mask) + lookup_array = np.zeros(max_label + 1, dtype=np.int32) + + cell_ids = np.unique(input_mask)[1:] + lookup_table = dict(zip(cell_ids, range(1, len(cell_ids)), strict=True)) + + # Populate lookup array based on the dictionary + for old_id, new_id in lookup_table.items(): + lookup_array[old_id] = new_id + + # Apply the mapping using NumPy indexing + remapped_mask = lookup_array[input_mask] + + return remapped_mask diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index e394cb7a..15ba6574 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -272,6 +272,9 @@ def _setup_segmentation_f(self, segmentation_f): from_project=True, ) + def _update_segmentation_f(self, segmentation_f): + self._setup_segmentation_f(segmentation_f) + def _setup_extraction_f(self, extraction_f): """Configure the extraction method for the project. @@ -332,6 +335,25 @@ def _setup_featurization_f(self, featurization_f): from_project=True, ) + def update_featurization_f(self, featurization_f): + """Update the featurization method chosen for the project without reinitializing the entire project. + + Args: + featurization_f : The featurization method that should be used for the project. + + Returns: + None : the featurization method is updated in the project object. + + Examples: + Update the featurization method for a project:: + + from scportrait.pipeline.featurization import CellFeaturizer + + project.update_featurization_f(CellFeaturizer) + """ + self.log(f"Replacing current featurization method {self.featurization_f.__class__} with {featurization_f}") + self._setup_featurization_f(featurization_f) + def _setup_selection(self, selection_f): """Configure the selection method for the project. @@ -360,25 +382,6 @@ def _setup_selection(self, selection_f): from_project=True, ) - def update_featurization_f(self, featurization_f): - """Update the featurization method chosen for the project without reinitializing the entire project. - - Args: - featurization_f : The featurization method that should be used for the project. - - Returns: - None : the featurization method is updated in the project object. - - Examples: - Update the featurization method for a project:: - - from scportrait.pipeline.featurization import CellFeaturizer - - project.update_featurization_f(CellFeaturizer) - """ - self.log(f"Replacing current featurization method {self.featurization_f.__class__} with {featurization_f}") - self._setup_featurization_f(featurization_f) - ##### General small helper functions #### def _check_memory(self, item): diff --git a/src/scportrait/pipeline/segmentation/segmentation.py b/src/scportrait/pipeline/segmentation/segmentation.py index a981f25e..0e471012 100644 --- a/src/scportrait/pipeline/segmentation/segmentation.py +++ b/src/scportrait/pipeline/segmentation/segmentation.py @@ -18,6 +18,7 @@ from PIL import Image from tqdm.auto import tqdm +from scportrait.io.daskmmap import dask_array_from_path from scportrait.pipeline._base import ProcessingStep from scportrait.pipeline._utils.segmentation import _return_edge_labels, sc_any, shift_labels @@ -119,7 +120,6 @@ def __init__( self.is_shard = False # additional parameters to configure level of debugging for developers - self.deep_debug = False self.save_filter_results = False self.nuc_seg_name = nuc_seg_name self.cyto_seg_name = cyto_seg_name @@ -301,6 +301,24 @@ def _save_segmentation(self, labels: np.array, classes: list) -> None: self.log("=== Finished segmentation of shard ===") + def _save_segmentation_sdata_from_memmap(self, temp_file_path, masks=None): + if masks is None: + masks = ["nuclei", "cytosol"] + + # connect to the temp file as a dask array + labels = dask_array_from_path(temp_file_path) + + if "nuclei" in masks: + ix = masks.index("nuclei") + + self.filehandler._write_segmentation_sdata(labels[ix], self.nuc_seg_name, overwrite=self.overwrite) + self.filehandler._add_centers(self.nuc_seg_name, overwrite=self.overwrite) + + if "cytosol" in masks: + ix = masks.index("cytosol") + self.filehandler._write_segmentation_sdata(labels[ix], self.cyto_seg_name, overwrite=self.overwrite) + self.filehandler._add_centers(self.cyto_seg_name, overwrite=self.overwrite) + def _save_segmentation_sdata(self, labels, classes, masks=None): if masks is None: masks = ["nuclei", "cytosol"] @@ -310,16 +328,12 @@ def _save_segmentation_sdata(self, labels, classes, masks=None): if "nuclei" in masks: ix = masks.index("nuclei") - self.filehandler._write_segmentation_sdata( - labels[ix], self.nuc_seg_name, classes=classes, overwrite=self.overwrite - ) + self.filehandler._write_segmentation_sdata(labels[ix], self.nuc_seg_name, overwrite=self.overwrite) self.filehandler._add_centers(self.nuc_seg_name, overwrite=self.overwrite) if "cytosol" in masks: ix = masks.index("cytosol") - self.filehandler._write_segmentation_sdata( - labels[ix], self.cyto_seg_name, classes=classes, overwrite=self.overwrite - ) + self.filehandler._write_segmentation_sdata(labels[ix], self.cyto_seg_name, overwrite=self.overwrite) self.filehandler._add_centers(self.cyto_seg_name, overwrite=self.overwrite) def save_map(self, map_name): @@ -422,7 +436,6 @@ def _call_as_shard(self): input_image = input_image[ :, self.window[0], self.window[1] ] # for some segmentation workflows potentially only the first channel is required this is further selected down in that segmentation workflow - self.input_image = input_image # track for potential plotting of intermediate results if self.deep_debug: self.log( @@ -907,14 +920,12 @@ def _resolve_sharding(self, sharding_plan): # save newly generated class list self._save_classes(list(filtered_classes_combined)) - # ensure cleanup - self.clear_temp_dir() - self.log("resolved sharding plan.") + del hdf_labels # ensure that memory is freed up # save final segmentation to sdata - self._save_segmentation_sdata(hdf_labels, list(filtered_classes_combined), masks=self.method.MASK_NAMES) - + self._save_segmentation_sdata_from_memmap(hdf_labels_path, masks=self.method.MASK_NAMES) + self.clear_temp_dir() # ensure cleanup self.log("finished saving segmentation results to sdata object for sharded segmentation.") if not self.deep_debug: diff --git a/src/scportrait/pipeline/segmentation/workflows.py b/src/scportrait/pipeline/segmentation/workflows.py index 4218a2aa..cb095990 100644 --- a/src/scportrait/pipeline/segmentation/workflows.py +++ b/src/scportrait/pipeline/segmentation/workflows.py @@ -34,6 +34,8 @@ class _BaseSegmentation(Segmentation): + MASK_NAMES = ["nucleus", "cytosol"] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._setup_channel_selection() @@ -75,7 +77,7 @@ def _setup_maximum_intensity_projection(self): def _define_channels_to_extract_for_segmentation(self): self.segmentation_channels = [] - if "nuclei" in self.MASK_NAMES: + if "nucleus" in self.MASK_NAMES: if "segmentation_channels_nuclei" in self.config.keys(): self.nucleus_segmentation_channel = self.config["segmentation_channels_nuclei"] elif "combine_nucleus_channels" in self.config.keys(): @@ -126,7 +128,7 @@ def _transform_input_image(self, input_image): values = [] # check if any channels need to be transformed - if "nuclei" in self.MASK_NAMES: + if "nucleus" in self.MASK_NAMES: if self.maximum_project_nucleus: self.log( f"For nucleus segmentation using the maximum intensity projection of channels {self.original_combine_nucleus_channels}." @@ -177,7 +179,7 @@ def return_empty_mask(self, input_image): _, x, y = input_image.shape self._save_segmentation_sdata(np.zeros((self.N_MASKS, x, y)), []) - def _check_seg_dtype(self, mask: np.array, mask_name: str) -> np.array: + def _check_seg_dtype(self, mask: np.ndarray, mask_name: str) -> np.ndarray: if not isinstance(mask, self.DEFAULT_SEGMENTATION_DTYPE): Warning( f"{mask_name} segmentation map is not of the correct dtype. \n Forcefully converting {mask.dtype} to {self.DEFAULT_SEGMENTATION_DTYPE}. \n This could lead to unexpected behaviour." @@ -481,19 +483,25 @@ def _median_correct_image(self, input_image, median_filter_size: int, debug: boo ##### Filtering Functions ##### # 1. Size Filtering - def _check_for_size_filtering(self, mask_types=None) -> None: + def _check_for_size_filtering(self, mask_types: list[str]) -> None: """ Check if size filtering should be performed on the masks. If size filtering is turned on, the thresholds for filtering are loaded from the config file. """ - if mask_types is None: - mask_types = ["nucleus", "cytosol"] + assert all( + mask_type in self.MASK_NAMES for mask_type in mask_types + ), f"mask_types must be a list of strings that are valid mask names {self.MASK_NAMES}." + if "filter_masks_size" in self.config.keys(): self.filter_size = self.config["filter_masks_size"] else: # default behaviour is this is turned off filtering can always be performed later and this preserves the whole segmentation mask self.filter_size = False + for mask_type in mask_types: + # save attributes for use later + setattr(self, f"{mask_type}_thresholds", None) + setattr(self, f"{mask_type}_confidence_interval", None) # load parameters for cellsize filtering if self.filter_size: @@ -604,10 +612,6 @@ def _perform_size_filtering( if plot_results: # get input image for visualization - if input_image is None: - if "input_image" in self.__dict__.keys(): - input_image = self.input_image - if input_image is not None: if len(input_image.shape) == 2: image_map = input_image @@ -717,10 +721,6 @@ def _perform_mask_matching_filtering( plot_results = True if plot_results: - if input_image is not None: - if "input_image" in self.__dict__.keys(): - input_image = self.input_image - if input_image is not None: # convert input image from uint16 to uint8 input_image = (input_image / 256).astype(np.uint8) @@ -1146,7 +1146,7 @@ def _cytosol_segmentation(self, input_image, debug: bool = False): class WGASegmentation(_ClassicalSegmentation): N_MASKS = 2 N_INPUT_CHANNELS = 2 - MASK_NAMES = ["nuclei", "cytosol"] + MASK_NAMES = ["nucleus", "cytosol"] DEFAULT_NUCLEI_CHANNEL_IDS = [0] DEFAULT_CYTOSOL_CHANNEL_IDS = [1] @@ -1220,7 +1220,7 @@ def __init__(self, *args, **kwargs): class DAPISegmentation(_ClassicalSegmentation): N_MASKS = 1 N_INPUT_CHANNELS = 1 - MASK_NAMES = ["nuclei"] + MASK_NAMES = ["nucleus"] DEFAULT_NUCLEI_CHANNEL_IDS = [0] def __init__(self, *args, **kwargs): @@ -1317,7 +1317,7 @@ def _read_cellpose_model(self, modeltype: str, name: str, gpu: str, device) -> m model = models.CellposeModel(pretrained_model=name, gpu=gpu, device=device) return model - def _load_model(self, model_type: str, gpu: str, device) -> tuple[float, models.Cellpose]: + def _load_model(self, model_type: str, gpu: str, device) -> models.Cellpose: """ Loads cellpose model @@ -1377,7 +1377,7 @@ def _load_model(self, model_type: str, gpu: str, device) -> tuple[float, models. return model - def _check_input_image_dtype(self, input_image): + def _check_input_image_dtype(self, input_image: np.ndarray): if input_image.dtype != self.DEFAULT_IMAGE_DTYPE: if isinstance(input_image.dtype, int): ValueError( @@ -1445,27 +1445,25 @@ def _check_gpu_status(self): class DAPISegmentationCellpose(_CellposeSegmentation): N_MASKS = 1 N_INPUT_CHANNELS = 1 - MASK_NAMES = ["nuclei"] + MASK_NAMES = ["nucleus"] DEFAULT_NUCLEI_CHANNEL_IDS = [0] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def _setup_filtering(self): - self._check_for_size_filtering(mask_types=["nucleus"]) + self._check_for_size_filtering(mask_types=self.MASK_NAMES) - def _finalize_segmentation_results(self): + def _finalize_segmentation_results(self, nucleus_mask: np.ndarray) -> np.ndarray: # ensure correct dtype of the maps - self.maps["nucleus_segmentation"] = self._check_seg_dtype( - mask=self.maps["nucleus_segmentation"], mask_name="nucleus" - ) + nucleus_mask = self._check_seg_dtype(mask=nucleus_mask, mask_name="nucleus") - segmentation = np.stack([self.maps["nucleus_segmentation"]]) + segmentation = np.stack([nucleus_mask]) return segmentation - def cellpose_segmentation(self, input_image): + def cellpose_segmentation(self, input_image: np.ndarray) -> np.ndarray: self._check_gpu_status() self._clear_cache() # ensure we start with an empty cache @@ -1489,7 +1487,7 @@ def cellpose_segmentation(self, input_image): cellprob_threshold=self.cellprob_threshold, channels=[1, 0], )[0] - masks = np.array(masks) # convert to array + masks = np.array(masks) # ensure all edge classes are removed masks = remove_edge_labels(masks) @@ -1500,18 +1498,15 @@ def cellpose_segmentation(self, input_image): if self.filter_size: masks = self._perform_size_filtering( mask=masks, - thresholds=self.nucleus_thresholds, - confidence_interval=self.nucleus_confidence_interval, + thresholds=self.nucleus_thresholds, # type: ignore + confidence_interval=self.nucleus_confidence_interval, # type: ignore mask_name="nucleus", log=True, input_image=input_image if self.debug else None, ) - # save segementation to maps for access from other subfunctions - self.maps["nucleus_segmentation"] = masks.reshape(masks.shape[1:]) - - # manually delete model and perform gc to free up memory on GPU - self._clear_cache(vars_to_delete=[model, masks]) + masks = masks.reshape(masks.shape[1:]) + return masks def _execute_segmentation(self, input_image): total_time_start = timeit.default_timer() @@ -1531,14 +1526,14 @@ def _execute_segmentation(self, input_image): } start_segmentation = timeit.default_timer() - self.cellpose_segmentation(input_image) + nucleus_mask = self.cellpose_segmentation(input_image) stop_segmentation = timeit.default_timer() self.segmentation_time = stop_segmentation - start_segmentation # finalize classes list - all_classes = set(np.unique(self.maps["nucleus_segmentation"])) - {0} + all_classes = set(np.unique(nucleus_mask)) - {0} - segmentation = self._finalize_segmentation_results() + segmentation = self._finalize_segmentation_results(nucleus_mask=nucleus_mask) self._save_segmentation_sdata(segmentation, all_classes, masks=self.MASK_NAMES) self.total_time = timeit.default_timer() - total_time_start @@ -1550,32 +1545,27 @@ class ShardedDAPISegmentationCellpose(ShardedSegmentation): class CytosolSegmentationCellpose(_CellposeSegmentation): N_MASKS = 2 N_INPUT_CHANNELS = 2 - MASK_NAMES = ["nuclei", "cytosol"] + MASK_NAMES = ["nucleus", "cytosol"] DEFAULT_NUCLEI_CHANNEL_IDS = [0] DEFAULT_CYTOSOL_CHANNEL_IDS = [1] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def _finalize_segmentation_results(self): + def _finalize_segmentation_results(self, mask_nucleus: np.ndarray, mask_cytosol: np.ndarray) -> np.ndarray: # ensure correct dtype of maps - self.maps["nucleus_segmentation"] = self._check_seg_dtype( - mask=self.maps["nucleus_segmentation"], mask_name="nucleus" - ) - self.maps["cytosol_segmentation"] = self._check_seg_dtype( - mask=self.maps["cytosol_segmentation"], mask_name="cytosol" - ) - - segmentation = np.stack([self.maps["nucleus_segmentation"], self.maps["cytosol_segmentation"]]) + mask_nucleus = self._check_seg_dtype(mask=mask_nucleus, mask_name="nucleus") + mask_cytosol = self._check_seg_dtype(mask=mask_cytosol, mask_name="cytosol") + segmentation = np.stack([mask_nucleus, mask_cytosol]) return segmentation def _setup_filtering(self): - self._check_for_size_filtering(mask_types=["nucleus", "cytosol"]) + self._check_for_size_filtering(mask_types=self.MASK_NAMES) self._check_for_mask_matching_filtering() - def cellpose_segmentation(self, input_image): + def cellpose_segmentation(self, input_image: np.ndarray) -> tuple[np.ndarray, np.ndarray]: self._check_gpu_status() self._clear_cache() # ensure we start with an empty cache @@ -1640,8 +1630,8 @@ def cellpose_segmentation(self, input_image): if self.filter_size: masks_nucleus = self._perform_size_filtering( mask=masks_nucleus, - thresholds=self.nucleus_thresholds, - confidence_interval=self.nucleus_confidence_interval, + thresholds=self.nucleus_thresholds, # type: ignore + confidence_interval=self.nucleus_confidence_interval, # type: ignore mask_name="nucleus", log=True, debug=self.debug, @@ -1650,8 +1640,8 @@ def cellpose_segmentation(self, input_image): masks_cytosol = self._perform_size_filtering( mask=masks_cytosol, - thresholds=self.nucleus_thresholds, - confidence_interval=self.nucleus_confidence_interval, + thresholds=self.nucleus_thresholds, # type: ignore + confidence_interval=self.nucleus_confidence_interval, # type: ignore mask_name="cytosol", log=True, debug=self.debug, @@ -1680,12 +1670,10 @@ def cellpose_segmentation(self, input_image): ### Cleanup Generated Segmentation masks ###################### - # first when the masks are finalized save them to the maps - self.maps["nucleus_segmentation"] = masks_nucleus.reshape(masks_nucleus.shape[1:]) + masks_nucleus = masks_nucleus.reshape(masks_nucleus.shape[1:]) + masks_cytosol = masks_cytosol.reshape(masks_cytosol.shape[1:]) - self.maps["cytosol_segmentation"] = masks_cytosol.reshape(masks_cytosol.shape[1:]) - - self._clear_cache(vars_to_delete=[masks_nucleus, masks_cytosol]) + return (masks_nucleus, masks_cytosol) def _execute_segmentation(self, input_image): total_time_start = timeit.default_timer() @@ -1696,29 +1684,15 @@ def _execute_segmentation(self, input_image): # check image dtype since cellpose expects int input images self._check_input_image_dtype(input_image) - # initialize location to save masks to - self.maps = { - "nucleus_segmentation": tempmmap.array( - shape=(1, input_image.shape[1], input_image.shape[2]), - dtype=self.DEFAULT_SEGMENTATION_DTYPE, - tmp_dir_abs_path=self._tmp_dir_path, - ), - "cytosol_segmentation": tempmmap.array( - shape=(1, input_image.shape[1], input_image.shape[2]), - dtype=self.DEFAULT_SEGMENTATION_DTYPE, - tmp_dir_abs_path=self._tmp_dir_path, - ), - } - start_segmentation = timeit.default_timer() - self.cellpose_segmentation(input_image) + masks_nucleus, masks_cytosol = self.cellpose_segmentation(input_image) stop_segmentation = timeit.default_timer() self.segmentation_time = stop_segmentation - start_segmentation # finalize segmentation classes ensuring that background is removed - all_classes = set(np.unique(self.maps["nucleus_segmentation"])) - {0} + all_classes = set(np.unique(masks_nucleus)) - {0} - segmentation = self._finalize_segmentation_results() + segmentation = self._finalize_segmentation_results(mask_nucleus=masks_nucleus, mask_cytosol=masks_cytosol) self._save_segmentation_sdata(segmentation, all_classes, masks=self.MASK_NAMES) # clean up memory @@ -1734,25 +1708,15 @@ class CytosolSegmentationDownsamplingCellpose(CytosolSegmentationCellpose): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def _finalize_segmentation_results(self): - self.maps["fullsize_nucleus_segmentation"] = self._rescale_downsampled_mask( - self.maps["nucleus_segmentation"], "nucleus_segmentation" - ) - self.maps["fullsize_cytosol_segmentation"] = self._rescale_downsampled_mask( - self.maps["cytosol_segmentation"], "cytosol_segmentation" - ) + def _finalize_segmentation_results(self, mask_nucleus: np.ndarray, mask_cytosol: np.ndarray) -> np.ndarray: + mask_nucleus = self._rescale_downsampled_mask(mask_nucleus, "nucleus_segmentation") + mask_cytosol = self._rescale_downsampled_mask(mask_cytosol, "cytosol_segmentation") - self.maps["fullsize_nucleus_segmentation"] = self._check_seg_dtype( - mask=self.maps["fullsize_nucleus_segmentation"], mask_name="nucleus" - ) - self.maps["fullsize_cytosol_segmentation"] = self._check_seg_dtype( - mask=self.maps["fullsize_cytosol_segmentation"], mask_name="cytosol" - ) + mask_nucleus = self._check_seg_dtype(mask=mask_nucleus, mask_name="nucleus") + mask_cytosol = self._check_seg_dtype(mask=mask_cytosol, mask_name="cytosol") # combine masks into one stack - segmentation = np.stack( - [self.maps["fullsize_nucleus_segmentation"], self.maps["fullsize_cytosol_segmentation"]] - ) + segmentation = np.stack([mask_nucleus, mask_cytosol]) return segmentation @@ -1775,39 +1739,15 @@ def _execute_segmentation(self, input_image): # downsample the image input_image = self._downsample_image(input_image) - # setup the memory mapped arrays to store the results - self.maps = { - "nucleus_segmentation": tempmmap.array( - shape=(1, input_image.shape[1], input_image.shape[2]), - dtype=self.DEFAULT_SEGMENTATION_DTYPE, - tmp_dir_abs_path=self._tmp_dir_path, - ), - "cytosol_segmentation": tempmmap.array( - shape=(1, input_image.shape[1], input_image.shape[2]), - dtype=self.DEFAULT_SEGMENTATION_DTYPE, - tmp_dir_abs_path=self._tmp_dir_path, - ), - "fullsize_nucleus_segmentation": tempmmap.array( - shape=(1, self.original_image_size[1], self.original_image_size[2]), - dtype=self.DEFAULT_SEGMENTATION_DTYPE, - tmp_dir_abs_path=self._tmp_dir_path, - ), - "fullsize_cytosol_segmentation": tempmmap.array( - shape=(1, self.original_image_size[1], self.original_image_size[2]), - dtype=self.DEFAULT_SEGMENTATION_DTYPE, - tmp_dir_abs_path=self._tmp_dir_path, - ), - } - start_segmentation = timeit.default_timer() - self.cellpose_segmentation(input_image) + mask_nucleus, mask_cytosol = self.cellpose_segmentation(input_image) stop_segmentation = timeit.default_timer() self.segmentation_time = stop_segmentation - start_segmentation # finalize classes list - all_classes = set(np.unique(self.maps["nucleus_segmentation"])) - {0} + all_classes = set(np.unique(mask_nucleus)) - {0} - segmentation = self._finalize_segmentation_results() + segmentation = self._finalize_segmentation_results(mask_nucleus=mask_nucleus, mask_cytosol=mask_cytosol) self._save_segmentation_sdata(segmentation, all_classes, masks=self.MASK_NAMES) self._clear_cache(vars_to_delete=[segmentation, all_classes]) @@ -1828,19 +1768,16 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def _setup_filtering(self): - self._check_for_size_filtering(mask_types=["cytosol"]) + self._check_for_size_filtering(mask_types=self.MASK_NAMES) - def _finalize_segmentation_results(self): + def _finalize_segmentation_results(self, cytosol_mask: np.ndarray) -> np.ndarray: # ensure correct dtype of maps - self.maps["cytosol_segmentation"] = self._check_seg_dtype( - mask=self.maps["cytosol_segmentation"], mask_name="cytosol" - ) - - segmentation = np.stack([self.maps["cytosol_segmentation"]]) + cytosol_mask = self._check_seg_dtype(mask=cytosol_mask, mask_name="cytosol") + segmentation = np.stack([cytosol_mask]) return segmentation - def cellpose_segmentation(self, input_image): + def cellpose_segmentation(self, input_image: np.ndarray) -> np.ndarray: self._setup_processing() self._clear_cache() @@ -1881,20 +1818,17 @@ def cellpose_segmentation(self, input_image): if self.filter_size: masks_cytosol = self._perform_size_filtering( mask=masks_cytosol, - thresholds=self.nucleus_thresholds, - confidence_interval=self.nucleus_confidence_interval, + thresholds=self.nucleus_thresholds, # type: ignore + confidence_interval=self.nucleus_confidence_interval, # type: ignore mask_name="cytosol", log=True, debug=self.debug, input_image=input_image if self.debug else None, ) - self.maps["cytosol_segmentation"] = masks_cytosol.reshape( - masks_cytosol.shape[1:] - ) # add reshape to match shape to HDF5 shape + masks_cytosol = masks_cytosol.reshape(masks_cytosol.shape[1:]) # add reshape to match shape to HDF5 shape - # clear memory - self._clear_cache(vars_to_delete=[masks_cytosol]) + return masks_cytosol def _execute_segmentation(self, input_image) -> None: total_time_start = timeit.default_timer() @@ -1905,25 +1839,16 @@ def _execute_segmentation(self, input_image) -> None: # check image dtype since cellpose expects int input images self._check_input_image_dtype(input_image) - # initialize location to save masks to - self.maps = { - "cytosol_segmentation": tempmmap.array( - shape=(1, input_image.shape[1], input_image.shape[2]), - dtype=self.DEFAULT_SEGMENTATION_DTYPE, - tmp_dir_abs_path=self._tmp_dir_path, - ), - } - # execute segmentation start_segmentation = timeit.default_timer() - self.cellpose_segmentation(input_image) + cytosol_mask = self.cellpose_segmentation(input_image) stop_segmentation = timeit.default_timer() self.segmentation_time = stop_segmentation - start_segmentation # get final classes list - all_classes = set(np.unique(self.maps["cytosol_segmentation"])) - {0} + all_classes = set(np.unique(cytosol_mask)) - {0} - segmentation = self._finalize_segmentation_results() + segmentation = self._finalize_segmentation_results(cytosol_mask) self._save_segmentation_sdata(segmentation, all_classes, masks=self.MASK_NAMES) # clean up memory @@ -1942,17 +1867,13 @@ class CytosolOnlySegmentationDownsamplingCellpose(CytosolOnlySegmentationCellpos def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def _finalize_segmentation_results(self): - self.maps["fullsize_cytosol_segmentation"] = self._rescale_downsampled_mask( - self.maps["cytosol_segmentation"], "cytosol_segmentation" - ) + def _finalize_segmentation_results(self, cytosol_mask: np.ndarray) -> np.ndarray: + cytosol_mask = self._rescale_downsampled_mask(cytosol_mask, "cytosol_segmentation") - self.maps["fullsize_cytosol_segmentation"] = self._check_seg_dtype( - mask=self.maps["fullsize_cytosol_segmentation"], mask_name="cytosol" - ) + cytosol_mask = self._check_seg_dtype(mask=cytosol_mask, mask_name="cytosol") # combine masks into one stack - segmentation = np.stack([self.maps["fullsize_cytosol_segmentation"]]) + segmentation = np.stack([cytosol_mask]) return segmentation @@ -1976,29 +1897,15 @@ def _execute_segmentation(self, input_image) -> None: # downsample the image input_image = self._downsample_image(input_image) - # setup the memory mapped arrays to store the results - self.maps = { - "cytosol_segmentation": tempmmap.array( - shape=(1, input_image.shape[1], input_image.shape[2]), - dtype=self.DEFAULT_SEGMENTATION_DTYPE, - tmp_dir_abs_path=self._tmp_dir_path, - ), - "fullsize_cytosol_segmentation": tempmmap.array( - shape=(1, self.original_image_size[1], self.original_image_size[2]), - dtype=self.DEFAULT_SEGMENTATION_DTYPE, - tmp_dir_abs_path=self._tmp_dir_path, - ), - } - start_segmentation = timeit.default_timer() - self.cellpose_segmentation(input_image) + cytosol_mask = self.cellpose_segmentation(input_image) stop_segmentation = timeit.default_timer() self.segmentation_time = stop_segmentation - start_segmentation # currently no implemented filtering steps to remove nuclei outside of specific thresholds - all_classes = set(np.unique(self.maps["cytosol_segmentation"])) - {0} + all_classes = set(np.unique(cytosol_mask)) - {0} - segmentation = self._finalize_segmentation_results() # type: ignore + segmentation = self._finalize_segmentation_results(cytosol_mask=cytosol_mask) # type: ignore self._save_segmentation_sdata(segmentation, all_classes, masks=self.MASK_NAMES) self.total_time = timeit.default_timer() - total_time_start