Skip to content

Commit

Permalink
Merge branch 'main' into fix_sharding_test
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Feb 24, 2025
2 parents 77b15dc + 15a6968 commit e52eff0
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 212 deletions.
8 changes: 1 addition & 7 deletions examples/code_snippets/segmentation/run_seg_any_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,8 @@
method = CytosolOnlySegmentationCellpose(config=config)
method.config = method.config["CytosolOnlySegmentationCellpose"]

# create datastructure to save results to
method.maps = {}

# perform segmentation
method.cellpose_segmentation(image)

# access results
seg_mask = method.maps["cytosol_segmentation"]
seg_mask = method.cellpose_segmentation(image)

# plot results
plt.imshow(seg_mask)
Expand Down
4 changes: 4 additions & 0 deletions src/scportrait/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from scportrait import processing as pp
from scportrait import tools as tl

# Python 3.12 is more strict about escape sequences in string literals
# mahotas: https://github.com/luispedro/mahotas/issues/151
warnings.filterwarnings("ignore", category=SyntaxWarning, message="invalid escape sequence")

# silence warning from spatialdata resulting in an older dask version see #139
warnings.filterwarnings("ignore", message="ignoring keyword argument 'read_only'")

Expand Down
4 changes: 1 addition & 3 deletions src/scportrait/pipeline/_utils/sdata_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ def _write_segmentation_object_sdata(
self,
segmentation_object: Labels2DModel,
segmentation_label: str,
classes: set[str] | None = None,
overwrite: bool = False,
) -> None:
"""Write segmentation object to SpatialData.
Expand All @@ -268,7 +267,6 @@ def _write_segmentation_sdata(
self,
segmentation: xarray.DataArray | np.ndarray,
segmentation_label: str,
classes: set[str] | None = None,
chunks: ChunkSize2D = (1000, 1000),
overwrite: bool = False,
) -> None:
Expand All @@ -292,7 +290,7 @@ def _write_segmentation_sdata(
if not get_chunk_size(mask) == chunks:
mask.data = mask.data.rechunk(chunks)

self._write_segmentation_object_sdata(mask, segmentation_label, classes=classes, overwrite=overwrite)
self._write_segmentation_object_sdata(mask, segmentation_label, overwrite=overwrite)

def _write_points_object_sdata(self, points: PointsModel, points_name: str, overwrite: bool = False) -> None:
"""Write points object to SpatialData.
Expand Down
18 changes: 18 additions & 0 deletions src/scportrait/pipeline/_utils/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,3 +671,21 @@ def sc_all(array: NDArray) -> bool:
if not x:
return False
return True


def remap_mask(input_mask: np.ndarray) -> np.ndarray:
    """Relabel a segmentation mask so cell ids become consecutive integers.

    Background (label 0) stays 0; all remaining unique labels are mapped to
    1..n_cells in ascending order of their original id.

    Args:
        input_mask: Integer segmentation mask of arbitrary shape.

    Returns:
        np.ndarray: Mask of the same shape with consecutive int32 labels.
    """
    # Guard the empty case: np.max raises on a zero-sized array.
    max_label = int(np.max(input_mask)) if input_mask.size else 0

    # Lookup table indexed by old label id, initialized to 0 so background
    # (and any unseen ids) map to background.
    lookup_array = np.zeros(max_label + 1, dtype=np.int32)

    # np.unique returns sorted ids; keep only foreground labels. Filtering by
    # value (rather than slicing off [0]) is also correct when 0 is absent.
    cell_ids = np.unique(input_mask)
    cell_ids = cell_ids[cell_ids != 0]

    # BUG FIX: the original zip(cell_ids, range(1, len(cell_ids)), strict=True)
    # paired n ids with n-1 targets and raised ValueError for any non-empty
    # mask. Vectorized assignment maps each id directly to 1..n.
    lookup_array[cell_ids] = np.arange(1, len(cell_ids) + 1, dtype=np.int32)

    # Apply the mapping in one pass via fancy indexing.
    return lookup_array[input_mask]
41 changes: 22 additions & 19 deletions src/scportrait/pipeline/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,9 @@ def _setup_segmentation_f(self, segmentation_f):
from_project=True,
)

def _update_segmentation_f(self, segmentation_f):
    """Update the segmentation method for the project without reinitializing it.

    Args:
        segmentation_f: The segmentation class to use for the project; passed
            straight through to ``_setup_segmentation_f``.
    """
    self._setup_segmentation_f(segmentation_f)

def _setup_extraction_f(self, extraction_f):
"""Configure the extraction method for the project.
Expand Down Expand Up @@ -332,6 +335,25 @@ def _setup_featurization_f(self, featurization_f):
from_project=True,
)

def update_featurization_f(self, featurization_f):
    """Update the featurization method chosen for the project without reinitializing the entire project.

    Args:
        featurization_f: The featurization method that should be used for the project.

    Returns:
        None: the featurization method is updated in the project object.

    Examples:
        Update the featurization method for a project::

            from scportrait.pipeline.featurization import CellFeaturizer
            project.update_featurization_f(CellFeaturizer)
    """
    self.log(f"Replacing current featurization method {self.featurization_f.__class__} with {featurization_f}")
    self._setup_featurization_f(featurization_f)

def _setup_selection(self, selection_f):
"""Configure the selection method for the project.
Expand Down Expand Up @@ -360,25 +382,6 @@ def _setup_selection(self, selection_f):
from_project=True,
)

def update_featurization_f(self, featurization_f):
"""Update the featurization method chosen for the project without reinitializing the entire project.
Args:
featurization_f : The featurization method that should be used for the project.
Returns:
None : the featurization method is updated in the project object.
Examples:
Update the featurization method for a project::
from scportrait.pipeline.featurization import CellFeaturizer
project.update_featurization_f(CellFeaturizer)
"""
self.log(f"Replacing current featurization method {self.featurization_f.__class__} with {featurization_f}")
self._setup_featurization_f(featurization_f)

##### General small helper functions ####

def _check_memory(self, item):
Expand Down
37 changes: 24 additions & 13 deletions src/scportrait/pipeline/segmentation/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from PIL import Image
from tqdm.auto import tqdm

from scportrait.io.daskmmap import dask_array_from_path
from scportrait.pipeline._base import ProcessingStep
from scportrait.pipeline._utils.segmentation import _return_edge_labels, sc_any, shift_labels

Expand Down Expand Up @@ -119,7 +120,6 @@ def __init__(
self.is_shard = False

# additional parameters to configure level of debugging for developers
self.deep_debug = False
self.save_filter_results = False
self.nuc_seg_name = nuc_seg_name
self.cyto_seg_name = cyto_seg_name
Expand Down Expand Up @@ -301,6 +301,24 @@ def _save_segmentation(self, labels: np.array, classes: list) -> None:

self.log("=== Finished segmentation of shard ===")

def _save_segmentation_sdata_from_memmap(self, temp_file_path, masks=None):
    """Write segmentation masks stored in a temporary memmap file to the sdata object.

    Args:
        temp_file_path: Path to the temporary file holding the stacked masks.
        masks: Names of the masks contained in the file, in stack order.
            Defaults to ``["nuclei", "cytosol"]``.
    """
    masks = ["nuclei", "cytosol"] if masks is None else masks

    # Open the temp file lazily as a dask array so the masks are never
    # loaded into memory in full.
    labels = dask_array_from_path(temp_file_path)

    # NOTE(review): assumes the first axis of `labels` indexes the masks in
    # the same order as `masks` — confirm against the writer of the temp file.
    for mask_name, seg_name in (("nuclei", self.nuc_seg_name), ("cytosol", self.cyto_seg_name)):
        if mask_name in masks:
            idx = masks.index(mask_name)
            self.filehandler._write_segmentation_sdata(labels[idx], seg_name, overwrite=self.overwrite)
            self.filehandler._add_centers(seg_name, overwrite=self.overwrite)

def _save_segmentation_sdata(self, labels, classes, masks=None):
if masks is None:
masks = ["nuclei", "cytosol"]
Expand All @@ -310,16 +328,12 @@ def _save_segmentation_sdata(self, labels, classes, masks=None):
if "nuclei" in masks:
ix = masks.index("nuclei")

self.filehandler._write_segmentation_sdata(
labels[ix], self.nuc_seg_name, classes=classes, overwrite=self.overwrite
)
self.filehandler._write_segmentation_sdata(labels[ix], self.nuc_seg_name, overwrite=self.overwrite)
self.filehandler._add_centers(self.nuc_seg_name, overwrite=self.overwrite)

if "cytosol" in masks:
ix = masks.index("cytosol")
self.filehandler._write_segmentation_sdata(
labels[ix], self.cyto_seg_name, classes=classes, overwrite=self.overwrite
)
self.filehandler._write_segmentation_sdata(labels[ix], self.cyto_seg_name, overwrite=self.overwrite)
self.filehandler._add_centers(self.cyto_seg_name, overwrite=self.overwrite)

def save_map(self, map_name):
Expand Down Expand Up @@ -422,7 +436,6 @@ def _call_as_shard(self):
input_image = input_image[
:, self.window[0], self.window[1]
] # for some segmentation workflows potentially only the first channel is required this is further selected down in that segmentation workflow
self.input_image = input_image # track for potential plotting of intermediate results

if self.deep_debug:
self.log(
Expand Down Expand Up @@ -907,14 +920,12 @@ def _resolve_sharding(self, sharding_plan):
# save newly generated class list
self._save_classes(list(filtered_classes_combined))

# ensure cleanup
self.clear_temp_dir()

self.log("resolved sharding plan.")
del hdf_labels # ensure that memory is freed up

# save final segmentation to sdata
self._save_segmentation_sdata(hdf_labels, list(filtered_classes_combined), masks=self.method.MASK_NAMES)

self._save_segmentation_sdata_from_memmap(hdf_labels_path, masks=self.method.MASK_NAMES)
self.clear_temp_dir() # ensure cleanup
self.log("finished saving segmentation results to sdata object for sharded segmentation.")

if not self.deep_debug:
Expand Down
Loading

0 comments on commit e52eff0

Please sign in to comment.