From afae503ea4f5d89bdcf4d7f98797ae0f43022411 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 14 May 2024 16:47:47 +0800 Subject: [PATCH 01/52] Fixes #7557 Add a function to create a JSON file that maps input and output paths. Signed-off-by: staydelight --- monai/data/image_reader.py | 98 ++++++++++++++++++++++++++++---------- monai/data/image_writer.py | 28 ++++++++++- 2 files changed, 99 insertions(+), 27 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index f5e199e2a3..fa2c63b2e3 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -11,9 +11,11 @@ from __future__ import annotations +import json +import logging +import sys import glob import os -import re import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence @@ -21,6 +23,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from monai.apps.utils import get_logger import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -51,6 +54,16 @@ pydicom, has_pydicom = optional_import("pydicom") nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) +DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s" + +logger = get_logger(module_name=__name__, fmt=DEFAULT_FMT) +logger = logging.getLogger(__name__) +handler = logging.StreamHandler(sys.stdout) +handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) +logger.addHandler(handler) +logger.setLevel(logging.DEBUG) + + __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] @@ -98,8 +111,10 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | kwargs: additional args for actual `read` API of 3rd party libs. """ + #self.update_json(input_file=data) raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + @abstractmethod def get_data(self, img) -> tuple[np.ndarray, dict]: """ @@ -147,6 +162,24 @@ def _stack_images(image_list: list, meta_dict: dict): meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 return np.stack(image_list, axis=0) +def update_json(input_file=None, output_file=None): + record_path = "img-label.json" + + if not os.path.exists(record_path) or os.stat(record_path).st_size == 0: + with open(record_path, 'w') as f: + json.dump([], f) + + with open(record_path, 'r+') as f: + records = json.load(f) + if input_file: + new_record = {"image": input_file, "label": []} + records.append(new_record) + elif output_file and records: + records[-1]["label"].append(output_file) + + f.seek(0) + json.dump(records, f, indent=4) + @require_pkg(pkg_name="itk") class ITKReader(ImageReader): @@ -168,8 +201,8 @@ class ITKReader(ImageReader): series_name: the name of the DICOM series if there are multiple ones. used when loading DICOM series. reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. - If ``False``, the spatial indexing convention is reversed to be compatible with ITK; - otherwise, the spatial indexing follows the numpy convention. Default is ``False``. + If ``False``, the spatial indexing follows the numpy convention; + otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is ``False``. This option does not affect the metadata. series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice). 
             This flag is checked only when loading DICOM series. Default is ``False``.
@@ -225,6 +258,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
         img_ = []
 
         filenames: Sequence[PathLike] = ensure_tuple(data)
+        update_json(input_file=filenames)
         kwargs_ = self.kwargs.copy()
         kwargs_.update(kwargs)
         for name in filenames:
@@ -332,6 +366,25 @@ def _get_affine(self, img, lps_to_ras: bool = True):
         affine[:sr, -1] = origin[:sr]
         if lps_to_ras:
             affine = orientation_ras_lps(affine)
+            logger.debug("lps is changed to ras")
+
+        # Use the logger to output information
+
+        logger.info("\nOrigin[:sr]:")
+        logger.info(", ".join(f"{x:.10f}" for x in origin[:sr]))
+
+        logger.info("\nDirection[:sr, :sr]:")
+        for row in direction[:sr, :sr]:
+            logger.info(", ".join(f"{x:.15f}" for x in row))
+
+        logger.info("\nSpacing[:sr]:")
+        logger.info(", ".join(f"{x:.15f}" for x in spacing[:sr]))
+
+
+        # affine = numpy.round(affine, decimals=5)
+
+        logger.debug(f"Affine matrix:\n{affine}")
+
         return affine
 
     def _get_spatial_shape(self, img):
@@ -404,12 +457,8 @@ class PydicomReader(ImageReader):
         label_dict: label of the dicom data. If provided, it will be used when loading segmentation data.
             Keys of the dict are the classes, and values are the corresponding class number. For example:
             for TCIA collection "C4KC-KiTS", it can be: {"Kidney": 0, "Renal Tumor": 1}.
-        fname_regex: a regular expression to match the file names when the input is a folder.
-            If provided, only the matched files will be included. For example, to include the file name
-            "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`.
-            Set it to `None` to use `pydicom.misc.is_dicom` to match valid files.
         kwargs: additional args for `pydicom.dcmread` API. more details about available args:
-            https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html
+            https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html#pydicom.filereader.dcmread
             If the `get_data` function will be called
             (for example, when using this reader with `monai.transforms.LoadImage`), please ensure that the argument
             `stop_before_pixels` is `True`, and `specific_tags` covers all necessary tags, such as `PixelSpacing`,
@@ -423,7 +472,6 @@ def __init__(
         swap_ij: bool = True,
         prune_metadata: bool = True,
         label_dict: dict | None = None,
-        fname_regex: str = "",
         **kwargs,
     ):
         super().__init__()
@@ -433,7 +481,6 @@
         self.swap_ij = swap_ij
         self.prune_metadata = prune_metadata
         self.label_dict = label_dict
-        self.fname_regex = fname_regex
 
     def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
         """
@@ -465,6 +512,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
         img_ = []
 
         filenames: Sequence[PathLike] = ensure_tuple(data)
+        update_json(input_file=filenames)
         kwargs_ = self.kwargs.copy()
         kwargs_.update(kwargs)
 
@@ -474,16 +522,9 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
             name = f"{name}"
             if Path(name).is_dir():
                 # read DICOM series
-                if self.fname_regex is not None:
-                    series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if re.match(self.fname_regex, slc)]
-                else:
-                    series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)]
-                slices = []
-                for slc in series_slcs:
-                    try:
-                        slices.append(pydicom.dcmread(fp=slc, **kwargs_))
-                    except pydicom.errors.InvalidDicomError as e:
-                        warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2)
+                series_slcs = glob.glob(os.path.join(name, "*"))
+                series_slcs = 
[slc for slc in series_slcs if "LICENSE" not in slc] + slices = [pydicom.dcmread(fp=slc, **kwargs_) for slc in series_slcs] img_.append(slices if len(slices) > 1 else slices[0]) if len(slices) > 1: self.has_series = True @@ -913,9 +954,11 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py """ + logger.info(f"Reading NIfTI data from: {data}") img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) + update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -1076,13 +1119,14 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) + update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: img = np.load(name, allow_pickle=True, **kwargs_) if Path(name).name.endswith(".npz"): # load expected items from NPZ file - npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys + npz_keys = [f"arr_{i}" for i in range(len(img))] if self.npz_keys is None else self.npz_keys for k in npz_keys: img_.append(img[k]) else: @@ -1173,6 +1217,7 @@ def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): img_: list[PILImage.Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) + update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -1297,10 +1342,11 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | """ img_: list = [] filenames: Sequence[PathLike] = ensure_tuple(data) + update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: - nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, **kwargs_)) + nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, *kwargs_)) img_.append(nrrd_image) return img_ if len(filenames) > 1 else img_[0] @@ -1323,7 +1369,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: header = dict(i.header) if self.index_order == "C": header = self._convert_f_to_c_order(header) - header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header) + header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) if self.affine_lps_to_ras: header = self._switch_lps_ras(header) @@ -1344,7 +1390,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: return _stack_images(img_array, compatible_meta), compatible_meta - def _get_affine(self, header: dict) -> np.ndarray: + def _get_affine(self, img: NrrdImage) -> np.ndarray: """ Get the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. 
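For reference, the `update_json()` helper added above appends one record per `read()` call and each writer later appends its output path to the last record. A minimal sketch of the resulting img-label.json after one read followed by one write; the file names here are illustrative, not part of the patch:

    [
        {
            "image": ["input_image.nii.gz"],
            "label": ["./out/input_image_seg.nii.gz"]
        }
    ]

Note that "image" holds a list even for a single file, because `read()` passes the whole `filenames` tuple to `update_json()`.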
@@ -1353,8 +1399,8 @@ def _get_affine(self, header: dict) -> np.ndarray: img: A `NrrdImage` loaded from image file """ - direction = header["space directions"] - origin = header["space origin"] + direction = img.header["space directions"] + origin = img.header["space origin"] x, y = direction.shape affine_diam = min(x, y) + 1 diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index b9e8b9e68e..06209c664a 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING, Any, cast import numpy as np +import os +import json from monai.apps.utils import get_logger from monai.config import DtypeLike, NdarrayOrTensor, PathLike @@ -196,6 +198,25 @@ def write(self, filename: PathLike, verbose: bool = True, **kwargs): if verbose: logger.info(f"writing: {filename}") + def update_json(self, input_file=None, output_file=None): + record_path = "img-label.json" + + if not os.path.exists(record_path) or os.stat(record_path).st_size == 0: + with open(record_path, 'w') as f: + json.dump([], f) + + with open(record_path, 'r+') as f: + records = json.load(f) + if input_file: + new_record = {"image": input_file, "label": []} + records.append(new_record) + elif output_file and records: + records[-1]["label"].append(output_file) + + f.seek(0) + json.dump(records, f, indent=4) + + @classmethod def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray: """ @@ -276,7 +297,7 @@ def resample_if_needed( # convert back at the end if isinstance(output_array, MetaTensor): output_array.applied_operations = [] - data_array, *_ = convert_data_type(output_array, output_type=orig_type) + data_array, *_ = convert_data_type(output_array, output_type=orig_type) # type: ignore affine, *_ = convert_data_type(output_array.affine, output_type=orig_type) # type: ignore return data_array[0], affine @@ -462,7 +483,9 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L809 """ + logger.info(f"ITKWriter is processing the file: {filename}") super().write(filename, verbose=verbose) + super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( cast(NdarrayOrTensor, self.data_obj), channel_dim=self.channel_dim, @@ -625,7 +648,9 @@ def write(self, filename: PathLike, verbose: bool = False, **obj_kwargs): - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.save """ + logger.info(f"NibabelWriter is processing the file: {filename}") super().write(filename, verbose=verbose) + super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( cast(NdarrayOrTensor, self.data_obj), affine=self.affine, dtype=self.output_dtype, **obj_kwargs ) @@ -771,6 +796,7 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): - https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.save """ super().write(filename, verbose=verbose) + super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( data_array=self.data_obj, dtype=self.output_dtype, From 542a77d5cee53c15dc6c17fef7961ec528e5a299 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 14 May 2024 17:33:04 +0800 Subject: [PATCH 02/52] Fixes #7557 Remove changes unrelated to this issue. 
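For context, a minimal round trip that would exercise the new hooks (a sketch only; the file names, output directory, and postfix are assumptions, not part of this change):

    from monai.transforms import LoadImage, SaveImage

    # reading goes through the reader's read(), which appends
    # {"image": [...], "label": []} to img-label.json
    img = LoadImage(image_only=True)("input_image.nii.gz")

    # writing goes through the writer's write(), which appends the saved
    # path to the last record's "label" list
    SaveImage(output_dir="./out", output_postfix="seg")(img)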
Signed-off-by: staydelight --- monai/data/image_reader.py | 1476 ++++++++++++++++++++++++++++++++++-- 1 file changed, 1426 insertions(+), 50 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index fa2c63b2e3..257bebc831 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -11,11 +11,10 @@ from __future__ import annotations -import json -import logging -import sys import glob +import json import os +import re import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence @@ -23,7 +22,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -from monai.apps.utils import get_logger import numpy as np from torch.utils.data._utils.collate import np_str_obj_array_pattern @@ -54,16 +52,6 @@ pydicom, has_pydicom = optional_import("pydicom") nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) -DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s" - -logger = get_logger(module_name=__name__, fmt=DEFAULT_FMT) -logger = logging.getLogger(__name__) -handler = logging.StreamHandler(sys.stdout) -handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) -logger.addHandler(handler) -logger.setLevel(logging.DEBUG) - - __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] @@ -111,10 +99,8 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | kwargs: additional args for actual `read` API of 3rd party libs. """ - #self.update_json(input_file=data) raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - @abstractmethod def get_data(self, img) -> tuple[np.ndarray, dict]: """ @@ -161,7 +147,8 @@ def _stack_images(image_list: list, meta_dict: dict): # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 return np.stack(image_list, axis=0) - + + def update_json(input_file=None, output_file=None): record_path = "img-label.json" @@ -201,8 +188,8 @@ class ITKReader(ImageReader): series_name: the name of the DICOM series if there are multiple ones. used when loading DICOM series. reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. - If ``False``, the spatial indexing follows the numpy convention; - otherwise, the spatial indexing convention is reversed to be compatible with ITK. Default is ``False``. + If ``False``, the spatial indexing convention is reversed to be compatible with ITK; + otherwise, the spatial indexing follows the numpy convention. Default is ``False``. This option does not affect the metadata. series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice). This flag is checked only when loading DICOM series. Default is ``False``. 
@@ -366,25 +353,6 @@ def _get_affine(self, img, lps_to_ras: bool = True):
         affine[:sr, -1] = origin[:sr]
         if lps_to_ras:
             affine = orientation_ras_lps(affine)
-            logger.debug("lps is changed to ras")
-
-        # Use the logger to output information
-
-        logger.info("\nOrigin[:sr]:")
-        logger.info(", ".join(f"{x:.10f}" for x in origin[:sr]))
-
-        logger.info("\nDirection[:sr, :sr]:")
-        for row in direction[:sr, :sr]:
-            logger.info(", ".join(f"{x:.15f}" for x in row))
-
-        logger.info("\nSpacing[:sr]:")
-        logger.info(", ".join(f"{x:.15f}" for x in spacing[:sr]))
-
-
-        # affine = numpy.round(affine, decimals=5)
-
-        logger.debug(f"Affine matrix:\n{affine}")
-
         return affine
 
     def _get_spatial_shape(self, img):
@@ -457,8 +425,12 @@ class PydicomReader(ImageReader):
         label_dict: label of the dicom data. If provided, it will be used when loading segmentation data.
             Keys of the dict are the classes, and values are the corresponding class number. For example:
             for TCIA collection "C4KC-KiTS", it can be: {"Kidney": 0, "Renal Tumor": 1}.
+        fname_regex: a regular expression to match the file names when the input is a folder.
+            If provided, only the matched files will be included. For example, to include the file name
+            "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`.
+            Set it to `None` to use `pydicom.misc.is_dicom` to match valid files.
         kwargs: additional args for `pydicom.dcmread` API. more details about available args:
-            https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html#pydicom.filereader.dcmread
+            https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html
             If the `get_data` function will be called
             (for example, when using this reader with `monai.transforms.LoadImage`), please ensure that the argument
             `stop_before_pixels` is `True`, and `specific_tags` covers all necessary tags, such as `PixelSpacing`,
@@ -472,6 +444,7 @@ def __init__(
         swap_ij: bool = True,
         prune_metadata: bool = True,
         label_dict: dict | None = None,
+        fname_regex: str = "",
         **kwargs,
     ):
         super().__init__()
@@ -481,6 +454,7 @@
         self.swap_ij = swap_ij
         self.prune_metadata = prune_metadata
         self.label_dict = label_dict
+        self.fname_regex = fname_regex
 
     def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
         """
@@ -522,9 +496,16 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
             name = f"{name}"
             if Path(name).is_dir():
                 # read DICOM series
-                series_slcs = glob.glob(os.path.join(name, "*"))
-                series_slcs = [slc for slc in series_slcs if "LICENSE" not in slc]
-                slices = [pydicom.dcmread(fp=slc, **kwargs_) for slc in series_slcs]
+                if self.fname_regex is not None:
+                    series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if re.match(self.fname_regex, slc)]
+                else:
+                    series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)]
+                slices = []
+                for slc in series_slcs:
+                    try:
+                        slices.append(pydicom.dcmread(fp=slc, **kwargs_))
+                    except pydicom.errors.InvalidDicomError as e:
+                        warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2)
                 img_.append(slices if len(slices) > 1 else slices[0])
                 if len(slices) > 1:
                     self.has_series = True
@@ -954,7 +935,6 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
             https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
 
         """
-        logger.info(f"Reading NIfTI data from: {data}")
         img_: list[Nifti1Image] = []
 
         filenames: Sequence[PathLike] = ensure_tuple(data)
@@ -1126,7 +1106,7 @@ def read(self, data: Sequence[PathLike] 
| PathLike, **kwargs): img = np.load(name, allow_pickle=True, **kwargs_) if Path(name).name.endswith(".npz"): # load expected items from NPZ file - npz_keys = [f"arr_{i}" for i in range(len(img))] if self.npz_keys is None else self.npz_keys + npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys for k in npz_keys: img_.append(img[k]) else: @@ -1346,7 +1326,7 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: - nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, *kwargs_)) + nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, **kwargs_)) img_.append(nrrd_image) return img_ if len(filenames) > 1 else img_[0] @@ -1369,7 +1349,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: header = dict(i.header) if self.index_order == "C": header = self._convert_f_to_c_order(header) - header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i) + header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header) if self.affine_lps_to_ras: header = self._switch_lps_ras(header) @@ -1390,7 +1370,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: return _stack_images(img_array, compatible_meta), compatible_meta - def _get_affine(self, img: NrrdImage) -> np.ndarray: + def _get_affine(self, header: dict) -> np.ndarray: """ Get the affine matrix of the image, it can be used to correct spacing, orientation or execute spatial transforms. @@ -1399,8 +1379,8 @@ def _get_affine(self, img: NrrdImage) -> np.ndarray: img: A `NrrdImage` loaded from image file """ - direction = img.header["space directions"] - origin = img.header["space origin"] + direction = header["space directions"] + origin = header["space origin"] x, y = direction.shape affine_diam = min(x, y) + 1 @@ -1440,4 +1420,1400 @@ def _convert_f_to_c_order(self, header: dict) -> dict: header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] - return header + return header# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import glob +import os +import re +import warnings +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Iterator, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +from torch.utils.data._utils.collate import np_str_obj_array_pattern + +from monai.config import KeysCollection, PathLike +from monai.data.utils import ( + affine_to_spacing, + correct_nifti_header_if_necessary, + is_no_channel, + is_supported_format, + orientation_ras_lps, +) +from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg + +if TYPE_CHECKING: + import itk + import nibabel as nib + import nrrd + import pydicom + from nibabel.nifti1 import Nifti1Image + from PIL import Image as PILImage + + has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True +else: + itk, has_itk = optional_import("itk", allow_namespace_pkg=True) + nib, has_nib = optional_import("nibabel") + Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image") + PILImage, has_pil = optional_import("PIL.Image") + pydicom, has_pydicom = optional_import("pydicom") + nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) + +__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] + + +class ImageReader(ABC): + """ + An abstract class defines APIs to load image files. + + Typical usage of an implementation of this class is: + + .. code-block:: python + + image_reader = MyImageReader() + img_obj = image_reader.read(path_to_image) + img_data, meta_data = image_reader.get_data(img_obj) + + - The `read` call converts image filenames into image objects, + - The `get_data` call fetches the image data, as well as metadata. + - A reader should implement `verify_suffix` with the logic of checking the input filename + by the filename extensions. + + """ + + @abstractmethod + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: + """ + Verify whether the specified `filename` is supported by the current reader. + This method should return True if the reader is able to read the format suggested by the + `filename`. + + Args: + filename: file name or a list of file names to read. + if a list of files, verify all the suffixes. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any: + """ + Read image data from specified file or files. + Note that it returns a data object or a sequence of data objects. + + Args: + data: file name or a list of file names to read. + kwargs: additional args for actual `read` API of 3rd party libs. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + @abstractmethod + def get_data(self, img) -> tuple[np.ndarray, dict]: + """ + Extract data array and metadata from loaded image and return them. + This function must return two objects, the first is a numpy array of image data, + the second is a dictionary of metadata. + + Args: + img: an image object loaded from an image file or a list of image objects. 
+ + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +def _copy_compatible_dict(from_dict: dict, to_dict: dict): + if not isinstance(to_dict, dict): + raise ValueError(f"to_dict must be a Dict, got {type(to_dict)}.") + if not to_dict: + for key in from_dict: + datum = from_dict[key] + if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None: + continue + to_dict[key] = str(TraceKeys.NONE) if datum is None else datum # NoneType to string for default_collate + else: + affine_key, shape_key = MetaKeys.AFFINE, MetaKeys.SPATIAL_SHAPE + if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]): + raise RuntimeError( + "affine matrix of all images should be the same for channel-wise concatenation. " + f"Got {from_dict[affine_key]} and {to_dict[affine_key]}." + ) + if shape_key in from_dict and not np.allclose(from_dict[shape_key], to_dict[shape_key]): + raise RuntimeError( + "spatial_shape of all images should be the same for channel-wise concatenation. " + f"Got {from_dict[shape_key]} and {to_dict[shape_key]}." + ) + + +def _stack_images(image_list: list, meta_dict: dict): + if len(image_list) <= 1: + return image_list[0] + if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): + channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) + return np.concatenate(image_list, axis=channel_dim) + # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified + meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 + return np.stack(image_list, axis=0) + + +@require_pkg(pkg_name="itk") +class ITKReader(ImageReader): + """ + Load medical images based on ITK library. + All the supported image formats can be found at: + https://github.com/InsightSoftwareConsortium/ITK/tree/master/Modules/IO + The loaded data array will be in C order, for example, a 3D image NumPy + array index order will be `CDWH`. + + Args: + channel_dim: the channel dimension of the input image, default is None. + This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. + If None, `original_channel_dim` will be either `no_channel` or `-1`. + + - Nifti file is usually "channel last", so there is no need to specify this argument. + - PNG file usually has `GetNumberOfComponentsPerPixel()==3`, so there is no need to specify this argument. + + series_name: the name of the DICOM series if there are multiple ones. + used when loading DICOM series. + reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. + If ``False``, the spatial indexing convention is reversed to be compatible with ITK; + otherwise, the spatial indexing follows the numpy convention. Default is ``False``. + This option does not affect the metadata. + series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice). + This flag is checked only when loading DICOM series. Default is ``False``. + affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. + Set to ``True`` to be consistent with ``NibabelReader``, otherwise the affine matrix remains in the ITK convention. + kwargs: additional args for `itk.imread` API. 
more details about available args: + https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py + + """ + + def __init__( + self, + channel_dim: str | int | None = None, + series_name: str = "", + reverse_indexing: bool = False, + series_meta: bool = False, + affine_lps_to_ras: bool = True, + **kwargs, + ): + super().__init__() + self.kwargs = kwargs + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim + self.series_name = series_name + self.reverse_indexing = reverse_indexing + self.series_meta = series_meta + self.affine_lps_to_ras = affine_lps_to_ras + + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: + """ + Verify whether the specified file or files format is supported by ITK reader. + + Args: + filename: file name or a list of file names to read. + if a list of files, verify all the suffixes. + + """ + return has_itk + + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + If passing directory path instead of file path, will treat it as DICOM images series and read. + Note that the returned object is ITK image object or list of ITK image objects. + + Args: + data: file name or a list of file names to read, + kwargs: additional args for `itk.imread` API, will override `self.kwargs` for existing keys. + More details about available args: + https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py + + """ + img_ = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for name in filenames: + name = f"{name}" + if Path(name).is_dir(): + # read DICOM series + # https://examples.itk.org/src/io/gdcm/readdicomseriesandwrite3dimage/documentation + names_generator = itk.GDCMSeriesFileNames.New() + names_generator.SetUseSeriesDetails(True) + names_generator.AddSeriesRestriction("0008|0021") # Series Date + names_generator.SetDirectory(name) + series_uid = names_generator.GetSeriesUIDs() + + if len(series_uid) < 1: + raise FileNotFoundError(f"no DICOMs in: {name}.") + if len(series_uid) > 1: + warnings.warn(f"the directory: {name} contains more than one DICOM series.") + series_identifier = series_uid[0] if not self.series_name else self.series_name + name = names_generator.GetFileNames(series_identifier) + + name = name[0] if len(name) == 1 else name # type: ignore + _obj = itk.imread(name, **kwargs_) + if self.series_meta: + _reader = itk.ImageSeriesReader.New(FileNames=name) + _reader.Update() + _meta = _reader.GetMetaDataDictionaryArray() + if len(_meta) > 0: + # TODO: using the first slice's meta. this could be improved to filter unnecessary tags. + _obj.SetMetaDataDictionary(_meta[0]) + img_.append(_obj) + else: + img_.append(itk.imread(name, **kwargs_)) + return img_ if len(filenames) > 1 else img_[0] + + def get_data(self, img) -> tuple[np.ndarray, dict]: + """ + Extract data array and metadata from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of metadata. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + When loading a list of files, they are stacked together at a new dimension as the first dimension, + and the metadata of the first image is used to represent the output metadata. 
+ + Args: + img: an ITK image object loaded from an image file or a list of ITK image objects. + + """ + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} + + for i in ensure_tuple(img): + data = self._get_array_data(i) + img_array.append(data) + header = self._get_meta_dict(i) + header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i, self.affine_lps_to_ras) + header[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS + header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy() + header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) + if self.channel_dim is None: # default to "no_channel" or -1 + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 + ) + else: + header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + _copy_compatible_dict(header, compatible_meta) + + return _stack_images(img_array, compatible_meta), compatible_meta + + def _get_meta_dict(self, img) -> dict: + """ + Get all the metadata of the image and convert to dict type. + + Args: + img: an ITK image object loaded from an image file. + + """ + img_meta_dict = img.GetMetaDataDictionary() + meta_dict = {} + for key in img_meta_dict.GetKeys(): + if key.startswith("ITK_"): + continue + val = img_meta_dict[key] + meta_dict[key] = np.asarray(val) if type(val).__name__.startswith("itk") else val + + meta_dict["spacing"] = np.asarray(img.GetSpacing()) + return meta_dict + + def _get_affine(self, img, lps_to_ras: bool = True): + """ + Get or construct the affine matrix of the image, it can be used to correct + spacing, orientation or execute spatial transforms. + + Args: + img: an ITK image object loaded from an image file. + lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to True. + + """ + direction = itk.array_from_matrix(img.GetDirection()) + spacing = np.asarray(img.GetSpacing()) + origin = np.asarray(img.GetOrigin()) + + direction = np.asarray(direction) + sr = min(max(direction.shape[0], 1), 3) + affine: np.ndarray = np.eye(sr + 1) + affine[:sr, :sr] = direction[:sr, :sr] @ np.diag(spacing[:sr]) + affine[:sr, -1] = origin[:sr] + if lps_to_ras: + affine = orientation_ras_lps(affine) + return affine + + def _get_spatial_shape(self, img): + """ + Get the spatial shape of `img`. + + Args: + img: an ITK image object loaded from an image file. + + """ + sr = itk.array_from_matrix(img.GetDirection()).shape[0] + sr = max(min(sr, 3), 1) + _size = list(itk.size(img)) + if isinstance(self.channel_dim, int): + _size.pop(self.channel_dim) + return np.asarray(_size[:sr]) + + def _get_array_data(self, img): + """ + Get the raw array data of the image, converted to Numpy array. + + Following PyTorch conventions, the returned array data has contiguous channels, + e.g. for an RGB image, all red channel image pixels are contiguous in memory. + The last axis of the returned array is the channel axis. + + See also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Modules/Bridge/NumPy/wrapping/PyBuffer.i.in + + Args: + img: an ITK image object loaded from an image file. 
+ + """ + np_img = itk.array_view_from_image(img, keep_axes=False) + if img.GetNumberOfComponentsPerPixel() == 1: # handling spatial images + return np_img if self.reverse_indexing else np_img.T + # handling multi-channel images + return np_img if self.reverse_indexing else np.moveaxis(np_img.T, 0, -1) + + +@require_pkg(pkg_name="pydicom") +class PydicomReader(ImageReader): + """ + Load medical images based on Pydicom library. + All the supported image formats can be found at: + https://dicom.nema.org/medical/dicom/current/output/chtml/part10/chapter_7.html + + PydicomReader is also able to load segmentations, if a dicom file contains tag: `SegmentSequence`, the reader + will consider it as segmentation data, and to load it successfully, `PerFrameFunctionalGroupsSequence` is required + for dicom file, and for each frame of dicom file, `SegmentIdentificationSequence` is required. + This method refers to the Highdicom library. + + This class refers to: + https://nipy.org/nibabel/dicom/dicom_orientation.html#dicom-affine-formula + https://github.com/pydicom/contrib-pydicom/blob/master/input-output/pydicom_series.py + https://highdicom.readthedocs.io/en/latest/usage.html#parsing-segmentation-seg-images + + Args: + channel_dim: the channel dimension of the input image, default is None. + This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. + If None, `original_channel_dim` will be either `no_channel` or `-1`. + affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. + Set to ``True`` to be consistent with ``NibabelReader``, + otherwise the affine matrix remains in the Dicom convention. + swap_ij: whether to swap the first two spatial axes. Default to ``True``, so that the outputs + are consistent with the other readers. + prune_metadata: whether to prune the saved information in metadata. This argument is used for + `get_data` function. If True, only items that are related to the affine matrix will be saved. + Default to ``True``. + label_dict: label of the dicom data. If provided, it will be used when loading segmentation data. + Keys of the dict are the classes, and values are the corresponding class number. For example: + for TCIA collection "C4KC-KiTS", it can be: {"Kidney": 0, "Renal Tumor": 1}. + fname_regex: a regular expression to match the file names when the input is a folder. + If provided, only the matched files will be included. For example, to include the file name + "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`. + Set it to `None` to use `pydicom.misc.is_dicom` to match valid files. + kwargs: additional args for `pydicom.dcmread` API. more details about available args: + https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html + If the `get_data` function will be called + (for example, when using this reader with `monai.transforms.LoadImage`), please ensure that the argument + `stop_before_pixels` is `True`, and `specific_tags` covers all necessary tags, such as `PixelSpacing`, + `ImagePositionPatient`, `ImageOrientationPatient` and all `pixel_array` related tags. 
+ """ + + def __init__( + self, + channel_dim: str | int | None = None, + affine_lps_to_ras: bool = True, + swap_ij: bool = True, + prune_metadata: bool = True, + label_dict: dict | None = None, + fname_regex: str = "", + **kwargs, + ): + super().__init__() + self.kwargs = kwargs + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim + self.affine_lps_to_ras = affine_lps_to_ras + self.swap_ij = swap_ij + self.prune_metadata = prune_metadata + self.label_dict = label_dict + self.fname_regex = fname_regex + + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: + """ + Verify whether the specified file or files format is supported by Pydicom reader. + + Args: + filename: file name or a list of file names to read. + if a list of files, verify all the suffixes. + + """ + return has_pydicom + + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + If passing directory path instead of file path, will treat it as DICOM images series and read. + + Args: + data: file name or a list of file names to read, + kwargs: additional args for `pydicom.dcmread` API, will override `self.kwargs` for existing keys. + + Returns: + If `data` represents a filename: return a pydicom dataset object. + If `data` represents a list of filenames or a directory: return a list of pydicom dataset object. + If `data` represents a list of directories: return a list of list of pydicom dataset object. + + """ + img_ = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + + self.has_series = False + + for name in filenames: + name = f"{name}" + if Path(name).is_dir(): + # read DICOM series + if self.fname_regex is not None: + series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if re.match(self.fname_regex, slc)] + else: + series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)] + slices = [] + for slc in series_slcs: + try: + slices.append(pydicom.dcmread(fp=slc, **kwargs_)) + except pydicom.errors.InvalidDicomError as e: + warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2) + img_.append(slices if len(slices) > 1 else slices[0]) + if len(slices) > 1: + self.has_series = True + else: + ds = pydicom.dcmread(fp=name, **kwargs_) + img_.append(ds) + return img_ if len(filenames) > 1 else img_[0] + + def _combine_dicom_series(self, data: Iterable): + """ + Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new + dimension as the last dimension. + + The stack order depends on Instance Number. The metadata will be based on the + first slice's metadata, and some new items will be added: + + "spacing": the new spacing of the stacked slices. + "lastImagePositionPatient": `ImagePositionPatient` for the last slice, it will be used to achieve the affine + matrix. + "spatial_shape": the spatial shape of the stacked slices. + + Args: + data: a list of pydicom dataset objects. + Returns: + a tuple that consisted with data array and metadata. 
+ """ + slices: list = [] + # for a dicom series + for slc_ds in data: + if hasattr(slc_ds, "InstanceNumber"): + slices.append(slc_ds) + else: + warnings.warn(f"slice: {slc_ds.filename} does not have InstanceNumber tag, skip it.") + slices = sorted(slices, key=lambda s: s.InstanceNumber) + + if len(slices) == 0: + raise ValueError("the input does not have valid slices.") + + first_slice = slices[0] + average_distance = 0.0 + first_array = self._get_array_data(first_slice) + shape = first_array.shape + spacing = getattr(first_slice, "PixelSpacing", [1.0, 1.0, 1.0]) + prev_pos = getattr(first_slice, "ImagePositionPatient", (0.0, 0.0, 0.0))[2] + stack_array = [first_array] + for idx in range(1, len(slices)): + slc_array = self._get_array_data(slices[idx]) + slc_shape = slc_array.shape + slc_spacing = getattr(slices[idx], "PixelSpacing", (1.0, 1.0, 1.0)) + slc_pos = getattr(slices[idx], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2] + if not np.allclose(slc_spacing, spacing): + warnings.warn(f"the list contains slices that have different spacings {spacing} and {slc_spacing}.") + if shape != slc_shape: + warnings.warn(f"the list contains slices that have different shapes {shape} and {slc_shape}.") + average_distance += abs(prev_pos - slc_pos) + prev_pos = slc_pos + stack_array.append(slc_array) + + if len(slices) > 1: + average_distance /= len(slices) - 1 + spacing.append(average_distance) + stack_array = np.stack(stack_array, axis=-1) + stack_metadata = self._get_meta_dict(first_slice) + stack_metadata["spacing"] = np.asarray(spacing) + if hasattr(slices[-1], "ImagePositionPatient"): + stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1].ImagePositionPatient) + stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape + (len(slices),) + else: + stack_array = stack_array[0] + stack_metadata = self._get_meta_dict(first_slice) + stack_metadata["spacing"] = np.asarray(spacing) + stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape + + return stack_array, stack_metadata + + def get_data(self, data) -> tuple[np.ndarray, dict]: + """ + Extract data array and metadata from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of metadata. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + For dicom series within the input, all slices will be stacked first, + When loading a list of files (dicom file, or stacked dicom series), they are stacked together at a new + dimension as the first dimension, and the metadata of the first image is used to represent the output metadata. + + To use this function, all pydicom dataset objects (if not segmentation data) should contain: + `pixel_array`, `PixelSpacing`, `ImagePositionPatient` and `ImageOrientationPatient`. + + For segmentation data, we assume that the input is not a dicom series, and the object should contain + `SegmentSequence` in order to identify it. + In addition, tags (5200, 9229) and (5200, 9230) are required to achieve + `PixelSpacing`, `ImageOrientationPatient` and `ImagePositionPatient`. + + Args: + data: a pydicom dataset object, or a list of pydicom dataset objects, or a list of list of + pydicom dataset objects. 
+ + """ + + dicom_data = [] + # combine dicom series if exists + if self.has_series is True: + # a list, all objects within a list belong to one dicom series + if not isinstance(data[0], list): + dicom_data.append(self._combine_dicom_series(data)) + # a list of list, each inner list represents a dicom series + else: + for series in data: + dicom_data.append(self._combine_dicom_series(series)) + else: + # a single pydicom dataset object + if not isinstance(data, list): + data = [data] + for d in data: + if hasattr(d, "SegmentSequence"): + data_array, metadata = self._get_seg_data(d) + else: + data_array = self._get_array_data(d) + metadata = self._get_meta_dict(d) + metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape + dicom_data.append((data_array, metadata)) + + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} + + for data_array, metadata in ensure_tuple(dicom_data): + img_array.append(np.ascontiguousarray(np.swapaxes(data_array, 0, 1) if self.swap_ij else data_array)) + affine = self._get_affine(metadata, self.affine_lps_to_ras) + metadata[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS + if self.swap_ij: + affine = affine @ np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + sp_size = list(metadata[MetaKeys.SPATIAL_SHAPE]) + sp_size[0], sp_size[1] = sp_size[1], sp_size[0] + metadata[MetaKeys.SPATIAL_SHAPE] = ensure_tuple(sp_size) + metadata[MetaKeys.ORIGINAL_AFFINE] = affine + metadata[MetaKeys.AFFINE] = affine.copy() + if self.channel_dim is None: # default to "no_channel" or -1 + metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data_array.shape) == len(metadata[MetaKeys.SPATIAL_SHAPE]) else -1 + ) + else: + metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + metadata["spacing"] = affine_to_spacing( + metadata[MetaKeys.ORIGINAL_AFFINE], r=len(metadata[MetaKeys.SPATIAL_SHAPE]) + ) + + _copy_compatible_dict(metadata, compatible_meta) + + return _stack_images(img_array, compatible_meta), compatible_meta + + def _get_meta_dict(self, img) -> dict: + """ + Get all the metadata of the image and convert to dict type. + + Args: + img: a Pydicom dataset object. + + """ + + metadata = img.to_json_dict(suppress_invalid_tags=True) + + if self.prune_metadata: + prune_metadata = {} + for key in ["00200037", "00200032", "00280030", "52009229", "52009230"]: + if key in metadata.keys(): + prune_metadata[key] = metadata[key] + return prune_metadata + + # always remove Pixel Data "7FE00008" or "7FE00009" or "7FE00010" + # always remove Data Set Trailing Padding "FFFCFFFC" + for key in ["7FE00008", "7FE00009", "7FE00010", "FFFCFFFC"]: + if key in metadata.keys(): + metadata.pop(key) + + return metadata # type: ignore + + def _get_affine(self, metadata: dict, lps_to_ras: bool = True): + """ + Get or construct the affine matrix of the image, it can be used to correct + spacing, orientation or execute spatial transforms. + + Args: + metadata: metadata with dict type. + lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to True. 
+ + """ + affine: np.ndarray = np.eye(4) + if not ("00200037" in metadata and "00200032" in metadata): + return affine + # "00200037" is the tag of `ImageOrientationPatient` + rx, ry, rz, cx, cy, cz = metadata["00200037"]["Value"] + # "00200032" is the tag of `ImagePositionPatient` + sx, sy, sz = metadata["00200032"]["Value"] + # "00280030" is the tag of `PixelSpacing` + spacing = metadata["00280030"]["Value"] if "00280030" in metadata else (1.0, 1.0) + dr, dc = metadata.get("spacing", spacing)[:2] + affine[0, 0] = cx * dr + affine[0, 1] = rx * dc + affine[0, 3] = sx + affine[1, 0] = cy * dr + affine[1, 1] = ry * dc + affine[1, 3] = sy + affine[2, 0] = cz * dr + affine[2, 1] = rz * dc + affine[2, 2] = 1.0 + affine[2, 3] = sz + + # 3d + if "lastImagePositionPatient" in metadata: + t1n, t2n, t3n = metadata["lastImagePositionPatient"] + n = metadata[MetaKeys.SPATIAL_SHAPE][-1] + k1, k2, k3 = (t1n - sx) / (n - 1), (t2n - sy) / (n - 1), (t3n - sz) / (n - 1) + affine[0, 2] = k1 + affine[1, 2] = k2 + affine[2, 2] = k3 + + if lps_to_ras: + affine = orientation_ras_lps(affine) + return affine + + def _get_frame_data(self, img) -> Iterator: + """ + yield frames and description from the segmentation image. + This function is adapted from Highdicom: + https://github.com/herrmannlab/highdicom/blob/v0.18.2/src/highdicom/seg/utils.py + + which has the following license... + + # ========================================================================= + # https://github.com/herrmannlab/highdicom/blob/v0.18.2/LICENSE + # + # Copyright 2020 MGH Computational Pathology + # Permission is hereby granted, free of charge, to any person obtaining a + # copy of this software and associated documentation files (the + # "Software"), to deal in the Software without restriction, including + # without limitation the rights to use, copy, modify, merge, publish, + # distribute, sublicense, and/or sell copies of the Software, and to + # permit persons to whom the Software is furnished to do so, subject to + # the following conditions: + # The above copyright notice and this permission notice shall be included + # in all copies or substantial portions of the Software. + # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + # ========================================================================= + + (https://github.com/herrmannlab/highdicom/issues/188) + + Args: + img: a Pydicom dataset object that has attribute "SegmentSequence". + + """ + + if not hasattr(img, "PerFrameFunctionalGroupsSequence"): + raise NotImplementedError( + f"To read dicom seg: {img.filename}, 'PerFrameFunctionalGroupsSequence' is required." + ) + + frame_seg_nums = [] + for f in img.PerFrameFunctionalGroupsSequence: + if not hasattr(f, "SegmentIdentificationSequence"): + raise NotImplementedError( + f"To read dicom seg: {img.filename}, 'SegmentIdentificationSequence' is required for each frame." 
+ ) + frame_seg_nums.append(int(f.SegmentIdentificationSequence[0].ReferencedSegmentNumber)) + + frame_seg_nums_arr = np.array(frame_seg_nums) + + seg_descriptions = {int(f.SegmentNumber): f for f in img.SegmentSequence} + + for i in np.unique(frame_seg_nums_arr): + indices = np.where(frame_seg_nums_arr == i)[0] + yield (img.pixel_array[indices, ...], seg_descriptions[i]) + + def _get_seg_data(self, img): + """ + Get the array data and metadata of the segmentation image. + + Aegs: + img: a Pydicom dataset object that has attribute "SegmentSequence". + + """ + + metadata = self._get_meta_dict(img) + n_classes = len(img.SegmentSequence) + spatial_shape = list(img.pixel_array.shape) + spatial_shape[0] = spatial_shape[0] // n_classes + + if self.label_dict is not None: + metadata["labels"] = self.label_dict + all_segs = np.zeros([*spatial_shape, len(self.label_dict)]) + else: + metadata["labels"] = {} + all_segs = np.zeros([*spatial_shape, n_classes]) + + for i, (frames, description) in enumerate(self._get_frame_data(img)): + segment_label = getattr(description, "SegmentLabel", f"label_{i}") + class_name = getattr(description, "SegmentDescription", segment_label) + if class_name not in metadata["labels"].keys(): + metadata["labels"][class_name] = i + class_num = metadata["labels"][class_name] + all_segs[..., class_num] = frames + + all_segs = all_segs.transpose([1, 2, 0, 3]) + metadata[MetaKeys.SPATIAL_SHAPE] = all_segs.shape[:-1] + + if "52009229" in metadata.keys(): + shared_func_group_seq = metadata["52009229"]["Value"][0] + + # get `ImageOrientationPatient` + if "00209116" in shared_func_group_seq.keys(): + plane_orient_seq = shared_func_group_seq["00209116"]["Value"][0] + if "00200037" in plane_orient_seq.keys(): + metadata["00200037"] = plane_orient_seq["00200037"] + + # get `PixelSpacing` + if "00289110" in shared_func_group_seq.keys(): + pixel_measure_seq = shared_func_group_seq["00289110"]["Value"][0] + + if "00280030" in pixel_measure_seq.keys(): + pixel_spacing = pixel_measure_seq["00280030"]["Value"] + metadata["spacing"] = pixel_spacing + if "00180050" in pixel_measure_seq.keys(): + metadata["spacing"] += pixel_measure_seq["00180050"]["Value"] + + if self.prune_metadata: + metadata.pop("52009229") + + # get `ImagePositionPatient` + if "52009230" in metadata.keys(): + first_frame_func_group_seq = metadata["52009230"]["Value"][0] + if "00209113" in first_frame_func_group_seq.keys(): + plane_position_seq = first_frame_func_group_seq["00209113"]["Value"][0] + if "00200032" in plane_position_seq.keys(): + metadata["00200032"] = plane_position_seq["00200032"] + metadata["lastImagePositionPatient"] = metadata["52009230"]["Value"][-1]["00209113"]["Value"][0][ + "00200032" + ]["Value"] + if self.prune_metadata: + metadata.pop("52009230") + + return all_segs, metadata + + def _get_array_data(self, img): + """ + Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data + will be rescaled. The output data has the dtype np.float32 if the rescaling is applied. + + Args: + img: a Pydicom dataset object. 
+ + """ + # process Dicom series + if not hasattr(img, "pixel_array"): + raise ValueError(f"dicom data: {img.filename} does not have pixel_array.") + data = img.pixel_array + + slope, offset = 1.0, 0.0 + rescale_flag = False + if hasattr(img, "RescaleSlope"): + slope = img.RescaleSlope + rescale_flag = True + if hasattr(img, "RescaleIntercept"): + offset = img.RescaleIntercept + rescale_flag = True + if rescale_flag: + data = data.astype(np.float32) * slope + offset + + return data + + +@require_pkg(pkg_name="nibabel") +class NibabelReader(ImageReader): + """ + Load NIfTI format images based on Nibabel library. + + Args: + as_closest_canonical: if True, load the image as closest to canonical axis format. + squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) + channel_dim: the channel dimension of the input image, default is None. + this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. + if None, `original_channel_dim` will be either `no_channel` or `-1`. + most Nifti files are usually "channel last", no need to specify this argument for them. + kwargs: additional args for `nibabel.load` API. more details about available args: + https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py + + """ + + def __init__( + self, + channel_dim: str | int | None = None, + as_closest_canonical: bool = False, + squeeze_non_spatial_dims: bool = False, + **kwargs, + ): + super().__init__() + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim + self.as_closest_canonical = as_closest_canonical + self.squeeze_non_spatial_dims = squeeze_non_spatial_dims + self.kwargs = kwargs + + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: + """ + Verify whether the specified file or files format is supported by Nibabel reader. + + Args: + filename: file name or a list of file names to read. + if a list of files, verify all the suffixes. + + """ + suffixes: Sequence[str] = ["nii", "nii.gz"] + return has_nib and is_supported_format(filename, suffixes) + + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + Note that the returned object is Nibabel image object or list of Nibabel image objects. + + Args: + data: file name or a list of file names to read. + kwargs: additional args for `nibabel.load` API, will override `self.kwargs` for existing keys. + More details about available args: + https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py + + """ + img_: list[Nifti1Image] = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for name in filenames: + img = nib.load(name, **kwargs_) + img = correct_nifti_header_if_necessary(img) + img_.append(img) # type: ignore + return img_ if len(filenames) > 1 else img_[0] + + def get_data(self, img) -> tuple[np.ndarray, dict]: + """ + Extract data array and metadata from loaded image and return them. + This function returns two objects, first is numpy array of image data, second is dict of metadata. + It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. + When loading a list of files, they are stacked together at a new dimension as the first dimension, + and the metadata of the first image is used to present the output metadata. 
+
+        Args:
+            img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.
+
+        """
+        img_array: list[np.ndarray] = []
+        compatible_meta: dict = {}
+
+        for i in ensure_tuple(img):
+            header = self._get_meta_dict(i)
+            header[MetaKeys.AFFINE] = self._get_affine(i)
+            header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
+            header["as_closest_canonical"] = self.as_closest_canonical
+            if self.as_closest_canonical:
+                i = nib.as_closest_canonical(i)
+                header[MetaKeys.AFFINE] = self._get_affine(i)
+            header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
+            header[MetaKeys.SPACE] = SpaceKeys.RAS
+            data = self._get_array_data(i)
+            if self.squeeze_non_spatial_dims:
+                for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
+                    if data.shape[d - 1] == 1:
+                        data = data.squeeze(axis=d - 1)
+            img_array.append(data)
+            if self.channel_dim is None:  # default to "no_channel" or -1
+                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
+                    float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
+                )
+            else:
+                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
+            _copy_compatible_dict(header, compatible_meta)
+
+        return _stack_images(img_array, compatible_meta), compatible_meta
+
+    def _get_meta_dict(self, img) -> dict:
+        """
+        Get all the metadata of the image and convert to dict type.
+
+        Args:
+            img: a Nibabel image object loaded from an image file.
+
+        """
+        # swap to little endian as PyTorch doesn't support big endian
+        try:
+            header = img.header.as_byteswapped("<")
+        except ValueError:
+            header = img.header
+        return dict(header)
+
+    def _get_affine(self, img):
+        """
+        Get the affine matrix of the image, it can be used to correct
+        spacing, orientation or execute spatial transforms.
+
+        Args:
+            img: a Nibabel image object loaded from an image file.
+
+        """
+        return np.array(img.affine, copy=True)
+
+    def _get_spatial_shape(self, img):
+        """
+        Get the spatial shape of image data, it doesn't contain the channel dim.
+
+        Args:
+            img: a Nibabel image object loaded from an image file.
+
+        """
+        # swap to little endian as PyTorch doesn't support big endian
+        try:
+            header = img.header.as_byteswapped("<")
+        except ValueError:
+            header = img.header
+        dim = header.get("dim", None)
+        if dim is None:
+            dim = header.get("dims")  # mgh format?
+            dim = np.insert(dim, 0, 3)
+        ndim = dim[0]
+        size = list(dim[1:])
+        if not is_no_channel(self.channel_dim):
+            size.pop(int(self.channel_dim))  # type: ignore
+        spatial_rank = max(min(ndim, 3), 1)
+        return np.asarray(size[:spatial_rank])
+
+    def _get_array_data(self, img):
+        """
+        Get the raw array data of the image, converted to Numpy array.
+
+        Args:
+            img: a Nibabel image object loaded from an image file.
+
+        """
+        return np.asanyarray(img.dataobj, order="C")
+
+
+class NumpyReader(ImageReader):
+    """
+    Load NPY or NPZ format data based on the Numpy library; they can be arrays or pickled objects.
+    A typical usage is to load the `mask` data for a classification task.
+    It can load part of the npz file with specified `npz_keys`.
+
+    Args:
+        npz_keys: if loading npz file, only load the specified keys, if None, load all the items.
+            stack the loaded items together to construct a new first dimension.
+        channel_dim: if not None, explicitly specify the channel dim, otherwise, treat the array as no channel.
+        kwargs: additional args for `numpy.load` API except `allow_pickle`. more details about available args:
+            https://numpy.org/doc/stable/reference/generated/numpy.load.html
+
+    """
+
+    def __init__(self, npz_keys: KeysCollection | None = None, channel_dim: str | int | None = None, **kwargs):
+        super().__init__()
+        if npz_keys is not None:
+            npz_keys = ensure_tuple(npz_keys)
+        self.npz_keys = npz_keys
+        self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
+        self.kwargs = kwargs
+
+    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
+        """
+        Verify whether the specified file or files format is supported by Numpy reader.
+
+        Args:
+            filename: file name or a list of file names to read.
+                if a list of files, verify all the suffixes.
+        """
+        suffixes: Sequence[str] = ["npz", "npy"]
+        return is_supported_format(filename, suffixes)
+
+    def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
+        """
+        Read image data from specified file or files, it can read a list of data files
+        and stack them together as multi-channel data in `get_data()`.
+        Note that the returned object is Numpy array or list of Numpy arrays.
+
+        Args:
+            data: file name or a list of file names to read.
+            kwargs: additional args for `numpy.load` API except `allow_pickle`, will override `self.kwargs` for existing keys.
+                More details about available args:
+                https://numpy.org/doc/stable/reference/generated/numpy.load.html
+
+        """
+        img_: list[np.ndarray] = []
+
+        filenames: Sequence[PathLike] = ensure_tuple(data)
+        kwargs_ = self.kwargs.copy()
+        kwargs_.update(kwargs)
+        for name in filenames:
+            img = np.load(name, allow_pickle=True, **kwargs_)
+            if Path(name).name.endswith(".npz"):
+                # load expected items from NPZ file
+                npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys
+                for k in npz_keys:
+                    img_.append(img[k])
+            else:
+                img_.append(img)
+
+        return img_ if len(img_) > 1 else img_[0]
+
+    def get_data(self, img) -> tuple[np.ndarray, dict]:
+        """
+        Extract data array and metadata from loaded image and return them.
+        This function returns two objects, first is numpy array of image data, second is dict of metadata.
+        It constructs `spatial_shape` and stores it in the meta dict.
+        When loading a list of files, they are stacked together at a new dimension as the first dimension,
+        and the metadata of the first image is used to represent the output metadata.
+
+        Args:
+            img: a Numpy array loaded from a file or a list of Numpy arrays.
+
+        """
+        img_array: list[np.ndarray] = []
+        compatible_meta: dict = {}
+        if isinstance(img, np.ndarray):
+            img = (img,)
+
+        for i in ensure_tuple(img):
+            header: dict[MetaKeys, Any] = {}
+            if isinstance(i, np.ndarray):
+                # if `channel_dim` is None, can not detect the channel dim, use all the dims as spatial_shape
+                spatial_shape = np.asarray(i.shape)
+                if isinstance(self.channel_dim, int):
+                    spatial_shape = np.delete(spatial_shape, self.channel_dim)
+                header[MetaKeys.SPATIAL_SHAPE] = spatial_shape
+                header[MetaKeys.SPACE] = SpaceKeys.RAS
+            img_array.append(i)
+            header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
+                self.channel_dim if isinstance(self.channel_dim, int) else float("nan")
+            )
+            _copy_compatible_dict(header, compatible_meta)
+
+        return _stack_images(img_array, compatible_meta), compatible_meta
+
+
+@require_pkg(pkg_name="PIL")
+class PILReader(ImageReader):
+    """
+    Load common 2D image format (supports PNG, JPG, BMP) file or files from provided path.
+
+    Args:
+        converter: additional function to convert the image data after `read()`.
+            for example, use `converter=lambda image: image.convert("LA")` to convert image format.
+        reverse_indexing: whether to swap axis 0 and 1 after loading the array; this is enabled by default
+            so that output of the reader is consistent with the other readers. Set this option to ``False`` to use
+            the PIL backend's original spatial axes convention.
+        kwargs: additional args for `Image.open` API in `read()`, more details about available args:
+            https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open
+    """
+
+    def __init__(self, converter: Callable | None = None, reverse_indexing: bool = True, **kwargs):
+        super().__init__()
+        self.converter = converter
+        self.reverse_indexing = reverse_indexing
+        self.kwargs = kwargs
+
+    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
+        """
+        Verify whether the specified file or files format is supported by PIL reader.
+
+        Args:
+            filename: file name or a list of file names to read.
+                if a list of files, verify all the suffixes.
+        """
+        suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"]
+        return has_pil and is_supported_format(filename, suffixes)
+
+    def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs):
+        """
+        Read image data from specified file or files, it can read a list of images
+        and stack them together as multi-channel data in `get_data()`.
+        Note that the returned object is a PIL image or a list of PIL images.
+
+        Args:
+            data: file name or a list of file names to read.
+            kwargs: additional args for `Image.open` API in `read()`, will override `self.kwargs` for existing keys.
+                More details about available args:
+                https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open
+
+        """
+        img_: list[PILImage.Image] = []
+
+        filenames: Sequence[PathLike] = ensure_tuple(data)
+        kwargs_ = self.kwargs.copy()
+        kwargs_.update(kwargs)
+        for name in filenames:
+            img = PILImage.open(name, **kwargs_)
+            if callable(self.converter):
+                img = self.converter(img)
+            img_.append(img)
+
+        return img_ if len(filenames) > 1 else img_[0]
+
+    def get_data(self, img) -> tuple[np.ndarray, dict]:
+        """
+        Extract data array and metadata from loaded image and return them.
+        This function returns two objects, first is numpy array of image data, second is dict of metadata.
+        It computes `spatial_shape` and stores it in meta dict.
+        When loading a list of files, they are stacked together at a new dimension as the first dimension,
+        and the metadata of the first image is used to represent the output metadata.
+        Note that by default `self.reverse_indexing` is set to ``True``, which swaps axis 0 and 1 after loading
+        the array because the spatial axes definition in PIL is different from other common medical packages.
+
+        Args:
+            img: a PIL Image object loaded from a file or a list of PIL Image objects.
+
+        """
+        img_array: list[np.ndarray] = []
+        compatible_meta: dict = {}
+
+        for i in ensure_tuple(img):
+            header = self._get_meta_dict(i)
+            header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
+            data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i)
+            img_array.append(data)
+            header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
+                float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
+            )
+            _copy_compatible_dict(header, compatible_meta)
+
+        return _stack_images(img_array, compatible_meta), compatible_meta
+
+    def _get_meta_dict(self, img) -> dict:
+        """
+        Get all the metadata of the image and convert to dict type.
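+
+        For a typical PNG file the returned dict looks like this (values are illustrative):
+
+        .. code-block:: python
+
+            {"format": "PNG", "mode": "RGB", "width": 512, "height": 512}
+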
+        Args:
+            img: a PIL Image object loaded from an image file.
+
+        """
+        return {"format": img.format, "mode": img.mode, "width": img.width, "height": img.height}
+
+    def _get_spatial_shape(self, img):
+        """
+        Get the spatial shape of image data, it doesn't contain the channel dim.
+        Args:
+            img: a PIL Image object loaded from an image file.
+        """
+        return np.asarray((img.width, img.height))
+
+
+@dataclass
+class NrrdImage:
+    """Class to wrap nrrd image array and metadata header"""
+
+    array: np.ndarray
+    header: dict
+
+
+@require_pkg(pkg_name="nrrd")
+class NrrdReader(ImageReader):
+    """
+    Load NRRD format images based on pynrrd library.
+
+    Args:
+        channel_dim: the channel dimension of the input image, default is None.
+            This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.
+            If None, `original_channel_dim` will be either `no_channel` or `0`.
+            NRRD files are usually "channel first".
+        dtype: dtype of the data array when loading image.
+        index_order: specify whether the returned data array should be in C-order (`"C"`) or Fortran-order (`"F"`).
+            Numpy is usually in C-order, but the default in the NRRD header is Fortran-order.
+        affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``.
+            Set to ``True`` to be consistent with ``NibabelReader``, otherwise the affine matrix is unmodified.
+
+        kwargs: additional args for `nrrd.read` API. more details about available args:
+            https://github.com/mhe/pynrrd/blob/master/nrrd/reader.py
+
+    """
+
+    def __init__(
+        self,
+        channel_dim: str | int | None = None,
+        dtype: np.dtype | type | str | None = np.float32,
+        index_order: str = "F",
+        affine_lps_to_ras: bool = True,
+        **kwargs,
+    ):
+        self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim
+        self.dtype = dtype
+        self.index_order = index_order
+        self.affine_lps_to_ras = affine_lps_to_ras
+        self.kwargs = kwargs
+
+    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
+        """
+        Verify whether the specified `filename` is supported by pynrrd reader.
+
+        Args:
+            filename: file name or a list of file names to read.
+                if a list of files, verify all the suffixes.
+
+        """
+        suffixes: Sequence[str] = ["nrrd", "seg.nrrd"]
+        return has_nrrd and is_supported_format(filename, suffixes)
+
+    def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any:
+        """
+        Read image data from specified file or files.
+        Note that it returns a data object or a sequence of data objects.
+
+        Args:
+            data: file name or a list of file names to read.
+            kwargs: additional args for actual `read` API of 3rd party libs.
+
+        """
+        img_: list = []
+        filenames: Sequence[PathLike] = ensure_tuple(data)
+        kwargs_ = self.kwargs.copy()
+        kwargs_.update(kwargs)
+        for name in filenames:
+            nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, **kwargs_))
+            img_.append(nrrd_image)
+        return img_ if len(filenames) > 1 else img_[0]
+
+    def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]:
+        """
+        Extract data array and metadata from loaded image and return them.
+        This function must return two objects, the first is a numpy array of image data,
+        the second is a dictionary of metadata.
+
+        Args:
+            img: a `NrrdImage` loaded from an image file or a list of image objects.
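+
+        A minimal usage sketch (the file name here is hypothetical):
+
+        .. code-block:: python
+
+            reader = NrrdReader()
+            nrrd_img = reader.read("volume.nrrd")
+            data, meta = reader.get_data(nrrd_img)  # the affine in `meta` is assembled from the NRRD header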
+
+        """
+        img_array: list[np.ndarray] = []
+        compatible_meta: dict = {}
+
+        for i in ensure_tuple(img):
+            data = i.array.astype(self.dtype)
+            img_array.append(data)
+            header = dict(i.header)
+            if self.index_order == "C":
+                header = self._convert_f_to_c_order(header)
+            header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header)
+
+            if self.affine_lps_to_ras:
+                header = self._switch_lps_ras(header)
+            if header.get(MetaKeys.SPACE, "left-posterior-superior") == "left-posterior-superior":
+                header[MetaKeys.SPACE] = SpaceKeys.LPS  # assuming LPS if not specified
+
+            header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy()
+            header[MetaKeys.SPATIAL_SHAPE] = header["sizes"]
+            [header.pop(k) for k in ("sizes", "space origin", "space directions")]  # rm duplicated data in header
+
+            if self.channel_dim is None:  # default to "no_channel" or -1
+                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
+                    float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0
+                )
+            else:
+                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
+            _copy_compatible_dict(header, compatible_meta)
+
+        return _stack_images(img_array, compatible_meta), compatible_meta
+
+    def _get_affine(self, header: dict) -> np.ndarray:
+        """
+        Get the affine matrix of the image, it can be used to correct
+        spacing, orientation or execute spatial transforms.
+
+        Args:
+            header: the image metadata as dict, taken from a `NrrdImage` loaded from an image file
+
+        """
+        direction = header["space directions"]
+        origin = header["space origin"]
+
+        x, y = direction.shape
+        affine_diam = min(x, y) + 1
+        affine: np.ndarray = np.eye(affine_diam)
+        affine[:x, :y] = direction
+        affine[: (affine_diam - 1), -1] = origin  # len origin is always affine_diam - 1
+        return affine
+
+    def _switch_lps_ras(self, header: dict) -> dict:
+        """
+        For compatibility with nibabel, switch from LPS to RAS. Adapt affine matrix and
+        `space` argument in header accordingly. If no information of space is given in the header,
+        LPS is assumed and thus converted to RAS. If information about space is given,
+        but is not LPS, the unchanged header is returned.
+
+        Args:
+            header: The image metadata as dict
+
+        """
+        if "space" not in header or header["space"] == "left-posterior-superior":
+            header[MetaKeys.ORIGINAL_AFFINE] = orientation_ras_lps(header[MetaKeys.ORIGINAL_AFFINE])
+            header[MetaKeys.SPACE] = SpaceKeys.RAS
+        return header
+
+    def _convert_f_to_c_order(self, header: dict) -> dict:
+        """
+        All header fields of a NRRD are specified in `F` (Fortran) order, even if the image was read as C-ordered array.
+        1D arrays of header['space origin'] and header['sizes'] become inverted, e.g., [1,2,3] -> [3,2,1]
+        The 2D array for header['space directions'] is transposed along the anti-diagonal: [[1,0,0],[0,2,0],[0,0,3]] -> [[3,0,0],[0,2,0],[0,0,1]]
+        For more details refer to: https://pynrrd.readthedocs.io/en/latest/user-guide.html#index-ordering
+
+        Args:
+            header: The image metadata as dict
+
+        """
+
+        header["space directions"] = np.rot90(np.flip(header["space directions"], 0))
+        header["space origin"] = header["space origin"][::-1]
+        header["sizes"] = header["sizes"][::-1]
+        return header
\ No newline at end of file

From e9f7565c4adc62b239e01d53c93224d1c04d2085 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 14 May 2024 09:37:14 +0000
Subject: [PATCH 03/52] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 monai/data/image_reader.py | 25 +++++--------------------
 1 file changed, 5 insertions(+), 20 deletions(-)

diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py
index 257bebc831..e7240e6b96 100644
--- a/monai/data/image_reader.py
+++ b/monai/data/image_reader.py
@@ -147,8 +147,8 @@ def _stack_images(image_list: list, meta_dict: dict):
     # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified
     meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0
     return np.stack(image_list, axis=0)
-
-
+
+
 def update_json(input_file=None, output_file=None):
     record_path = "img-label.json"
 
@@ -1433,28 +1433,13 @@ def _convert_f_to_c_order(self, header: dict) -> dict:
 
 from __future__ import annotations
 
-import glob
-import os
-import re
-import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Iterable, Iterator, Sequence
 from dataclasses import dataclass
-from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
 
 import numpy as np
-from torch.utils.data._utils.collate import np_str_obj_array_pattern
 
-from monai.config import KeysCollection, PathLike
-from monai.data.utils import (
-    affine_to_spacing,
-    correct_nifti_header_if_necessary,
-    is_no_channel,
-    is_supported_format,
-    orientation_ras_lps,
-)
-from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg
+from monai.utils import optional_import, require_pkg
 
 if TYPE_CHECKING:
     import itk
@@ -2816,4 +2801,4 @@ def _convert_f_to_c_order(self, header: dict) -> dict:
     header["space directions"] = np.rot90(np.flip(header["space directions"], 0))
     header["space origin"] = header["space origin"][::-1]
     header["sizes"] = header["sizes"][::-1]
-    return header
\ No newline at end of file
+    return header

From 7969d21613087a3432de6782e7590182e9a06614 Mon Sep 17 00:00:00 2001
From: staydelight
Date: Tue, 14 May 2024 17:40:15 +0800
Subject: [PATCH 04/52] Fixes #7557

Remove changes unrelated to this issue.
Signed-off-by: staydelight --- monai/data/image_writer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 06209c664a..4b7d95e71a 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -297,7 +297,7 @@ def resample_if_needed( # convert back at the end if isinstance(output_array, MetaTensor): output_array.applied_operations = [] - data_array, *_ = convert_data_type(output_array, output_type=orig_type) # type: ignore + data_array, *_ = convert_data_type(output_array, output_type=orig_type) affine, *_ = convert_data_type(output_array.affine, output_type=orig_type) # type: ignore return data_array[0], affine @@ -483,7 +483,6 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L809 """ - logger.info(f"ITKWriter is processing the file: {filename}") super().write(filename, verbose=verbose) super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( @@ -648,7 +647,6 @@ def write(self, filename: PathLike, verbose: bool = False, **obj_kwargs): - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.save """ - logger.info(f"NibabelWriter is processing the file: {filename}") super().write(filename, verbose=verbose) super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( From 3ce5f30f5034bd9f954676a2a62da56ffc17664d Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 14 May 2024 17:58:19 +0800 Subject: [PATCH 05/52] Fixes #7557 Remove changes unrelated to this issue. Signed-off-by: staydelight --- monai/data/image_reader.py | 1383 +----------------------------------- 1 file changed, 1 insertion(+), 1382 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index e7240e6b96..d11140c110 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1420,1385 +1420,4 @@ def _convert_f_to_c_order(self, header: dict) -> dict: header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] - return header# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING - -import numpy as np - -from monai.utils import optional_import, require_pkg - -if TYPE_CHECKING: - import itk - import nibabel as nib - import nrrd - import pydicom - from nibabel.nifti1 import Nifti1Image - from PIL import Image as PILImage - - has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True -else: - itk, has_itk = optional_import("itk", allow_namespace_pkg=True) - nib, has_nib = optional_import("nibabel") - Nifti1Image, _ = optional_import("nibabel.nifti1", name="Nifti1Image") - PILImage, has_pil = optional_import("PIL.Image") - pydicom, has_pydicom = optional_import("pydicom") - nrrd, has_nrrd = optional_import("nrrd", allow_namespace_pkg=True) - -__all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] - - -class ImageReader(ABC): - """ - An abstract class defines APIs to load image files. - - Typical usage of an implementation of this class is: - - .. code-block:: python - - image_reader = MyImageReader() - img_obj = image_reader.read(path_to_image) - img_data, meta_data = image_reader.get_data(img_obj) - - - The `read` call converts image filenames into image objects, - - The `get_data` call fetches the image data, as well as metadata. - - A reader should implement `verify_suffix` with the logic of checking the input filename - by the filename extensions. - - """ - - @abstractmethod - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified `filename` is supported by the current reader. - This method should return True if the reader is able to read the format suggested by the - `filename`. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - - """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any: - """ - Read image data from specified file or files. - Note that it returns a data object or a sequence of data objects. - - Args: - data: file name or a list of file names to read. - kwargs: additional args for actual `read` API of 3rd party libs. - - """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - - @abstractmethod - def get_data(self, img) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function must return two objects, the first is a numpy array of image data, - the second is a dictionary of metadata. - - Args: - img: an image object loaded from an image file or a list of image objects. 
- - """ - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - - -def _copy_compatible_dict(from_dict: dict, to_dict: dict): - if not isinstance(to_dict, dict): - raise ValueError(f"to_dict must be a Dict, got {type(to_dict)}.") - if not to_dict: - for key in from_dict: - datum = from_dict[key] - if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None: - continue - to_dict[key] = str(TraceKeys.NONE) if datum is None else datum # NoneType to string for default_collate - else: - affine_key, shape_key = MetaKeys.AFFINE, MetaKeys.SPATIAL_SHAPE - if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]): - raise RuntimeError( - "affine matrix of all images should be the same for channel-wise concatenation. " - f"Got {from_dict[affine_key]} and {to_dict[affine_key]}." - ) - if shape_key in from_dict and not np.allclose(from_dict[shape_key], to_dict[shape_key]): - raise RuntimeError( - "spatial_shape of all images should be the same for channel-wise concatenation. " - f"Got {from_dict[shape_key]} and {to_dict[shape_key]}." - ) - - -def _stack_images(image_list: list, meta_dict: dict): - if len(image_list) <= 1: - return image_list[0] - if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)): - channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM]) - return np.concatenate(image_list, axis=channel_dim) - # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified - meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 - return np.stack(image_list, axis=0) - - -@require_pkg(pkg_name="itk") -class ITKReader(ImageReader): - """ - Load medical images based on ITK library. - All the supported image formats can be found at: - https://github.com/InsightSoftwareConsortium/ITK/tree/master/Modules/IO - The loaded data array will be in C order, for example, a 3D image NumPy - array index order will be `CDWH`. - - Args: - channel_dim: the channel dimension of the input image, default is None. - This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. - If None, `original_channel_dim` will be either `no_channel` or `-1`. - - - Nifti file is usually "channel last", so there is no need to specify this argument. - - PNG file usually has `GetNumberOfComponentsPerPixel()==3`, so there is no need to specify this argument. - - series_name: the name of the DICOM series if there are multiple ones. - used when loading DICOM series. - reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array. - If ``False``, the spatial indexing convention is reversed to be compatible with ITK; - otherwise, the spatial indexing follows the numpy convention. Default is ``False``. - This option does not affect the metadata. - series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice). - This flag is checked only when loading DICOM series. Default is ``False``. - affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. - Set to ``True`` to be consistent with ``NibabelReader``, otherwise the affine matrix remains in the ITK convention. - kwargs: additional args for `itk.imread` API. 
more details about available args: - https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py - - """ - - def __init__( - self, - channel_dim: str | int | None = None, - series_name: str = "", - reverse_indexing: bool = False, - series_meta: bool = False, - affine_lps_to_ras: bool = True, - **kwargs, - ): - super().__init__() - self.kwargs = kwargs - self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim - self.series_name = series_name - self.reverse_indexing = reverse_indexing - self.series_meta = series_meta - self.affine_lps_to_ras = affine_lps_to_ras - - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified file or files format is supported by ITK reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - - """ - return has_itk - - def read(self, data: Sequence[PathLike] | PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of images - and stack them together as multi-channel data in `get_data()`. - If passing directory path instead of file path, will treat it as DICOM images series and read. - Note that the returned object is ITK image object or list of ITK image objects. - - Args: - data: file name or a list of file names to read, - kwargs: additional args for `itk.imread` API, will override `self.kwargs` for existing keys. - More details about available args: - https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py - - """ - img_ = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - name = f"{name}" - if Path(name).is_dir(): - # read DICOM series - # https://examples.itk.org/src/io/gdcm/readdicomseriesandwrite3dimage/documentation - names_generator = itk.GDCMSeriesFileNames.New() - names_generator.SetUseSeriesDetails(True) - names_generator.AddSeriesRestriction("0008|0021") # Series Date - names_generator.SetDirectory(name) - series_uid = names_generator.GetSeriesUIDs() - - if len(series_uid) < 1: - raise FileNotFoundError(f"no DICOMs in: {name}.") - if len(series_uid) > 1: - warnings.warn(f"the directory: {name} contains more than one DICOM series.") - series_identifier = series_uid[0] if not self.series_name else self.series_name - name = names_generator.GetFileNames(series_identifier) - - name = name[0] if len(name) == 1 else name # type: ignore - _obj = itk.imread(name, **kwargs_) - if self.series_meta: - _reader = itk.ImageSeriesReader.New(FileNames=name) - _reader.Update() - _meta = _reader.GetMetaDataDictionaryArray() - if len(_meta) > 0: - # TODO: using the first slice's meta. this could be improved to filter unnecessary tags. - _obj.SetMetaDataDictionary(_meta[0]) - img_.append(_obj) - else: - img_.append(itk.imread(name, **kwargs_)) - return img_ if len(filenames) > 1 else img_[0] - - def get_data(self, img) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. - When loading a list of files, they are stacked together at a new dimension as the first dimension, - and the metadata of the first image is used to represent the output metadata. 
- - Args: - img: an ITK image object loaded from an image file or a list of ITK image objects. - - """ - img_array: list[np.ndarray] = [] - compatible_meta: dict = {} - - for i in ensure_tuple(img): - data = self._get_array_data(i) - img_array.append(data) - header = self._get_meta_dict(i) - header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i, self.affine_lps_to_ras) - header[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS - header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy() - header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i) - if self.channel_dim is None: # default to "no_channel" or -1 - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1 - ) - else: - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - _copy_compatible_dict(header, compatible_meta) - - return _stack_images(img_array, compatible_meta), compatible_meta - - def _get_meta_dict(self, img) -> dict: - """ - Get all the metadata of the image and convert to dict type. - - Args: - img: an ITK image object loaded from an image file. - - """ - img_meta_dict = img.GetMetaDataDictionary() - meta_dict = {} - for key in img_meta_dict.GetKeys(): - if key.startswith("ITK_"): - continue - val = img_meta_dict[key] - meta_dict[key] = np.asarray(val) if type(val).__name__.startswith("itk") else val - - meta_dict["spacing"] = np.asarray(img.GetSpacing()) - return meta_dict - - def _get_affine(self, img, lps_to_ras: bool = True): - """ - Get or construct the affine matrix of the image, it can be used to correct - spacing, orientation or execute spatial transforms. - - Args: - img: an ITK image object loaded from an image file. - lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to True. - - """ - direction = itk.array_from_matrix(img.GetDirection()) - spacing = np.asarray(img.GetSpacing()) - origin = np.asarray(img.GetOrigin()) - - direction = np.asarray(direction) - sr = min(max(direction.shape[0], 1), 3) - affine: np.ndarray = np.eye(sr + 1) - affine[:sr, :sr] = direction[:sr, :sr] @ np.diag(spacing[:sr]) - affine[:sr, -1] = origin[:sr] - if lps_to_ras: - affine = orientation_ras_lps(affine) - return affine - - def _get_spatial_shape(self, img): - """ - Get the spatial shape of `img`. - - Args: - img: an ITK image object loaded from an image file. - - """ - sr = itk.array_from_matrix(img.GetDirection()).shape[0] - sr = max(min(sr, 3), 1) - _size = list(itk.size(img)) - if isinstance(self.channel_dim, int): - _size.pop(self.channel_dim) - return np.asarray(_size[:sr]) - - def _get_array_data(self, img): - """ - Get the raw array data of the image, converted to Numpy array. - - Following PyTorch conventions, the returned array data has contiguous channels, - e.g. for an RGB image, all red channel image pixels are contiguous in memory. - The last axis of the returned array is the channel axis. - - See also: - - - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Modules/Bridge/NumPy/wrapping/PyBuffer.i.in - - Args: - img: an ITK image object loaded from an image file. 
- - """ - np_img = itk.array_view_from_image(img, keep_axes=False) - if img.GetNumberOfComponentsPerPixel() == 1: # handling spatial images - return np_img if self.reverse_indexing else np_img.T - # handling multi-channel images - return np_img if self.reverse_indexing else np.moveaxis(np_img.T, 0, -1) - - -@require_pkg(pkg_name="pydicom") -class PydicomReader(ImageReader): - """ - Load medical images based on Pydicom library. - All the supported image formats can be found at: - https://dicom.nema.org/medical/dicom/current/output/chtml/part10/chapter_7.html - - PydicomReader is also able to load segmentations, if a dicom file contains tag: `SegmentSequence`, the reader - will consider it as segmentation data, and to load it successfully, `PerFrameFunctionalGroupsSequence` is required - for dicom file, and for each frame of dicom file, `SegmentIdentificationSequence` is required. - This method refers to the Highdicom library. - - This class refers to: - https://nipy.org/nibabel/dicom/dicom_orientation.html#dicom-affine-formula - https://github.com/pydicom/contrib-pydicom/blob/master/input-output/pydicom_series.py - https://highdicom.readthedocs.io/en/latest/usage.html#parsing-segmentation-seg-images - - Args: - channel_dim: the channel dimension of the input image, default is None. - This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. - If None, `original_channel_dim` will be either `no_channel` or `-1`. - affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. - Set to ``True`` to be consistent with ``NibabelReader``, - otherwise the affine matrix remains in the Dicom convention. - swap_ij: whether to swap the first two spatial axes. Default to ``True``, so that the outputs - are consistent with the other readers. - prune_metadata: whether to prune the saved information in metadata. This argument is used for - `get_data` function. If True, only items that are related to the affine matrix will be saved. - Default to ``True``. - label_dict: label of the dicom data. If provided, it will be used when loading segmentation data. - Keys of the dict are the classes, and values are the corresponding class number. For example: - for TCIA collection "C4KC-KiTS", it can be: {"Kidney": 0, "Renal Tumor": 1}. - fname_regex: a regular expression to match the file names when the input is a folder. - If provided, only the matched files will be included. For example, to include the file name - "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`. - Set it to `None` to use `pydicom.misc.is_dicom` to match valid files. - kwargs: additional args for `pydicom.dcmread` API. more details about available args: - https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html - If the `get_data` function will be called - (for example, when using this reader with `monai.transforms.LoadImage`), please ensure that the argument - `stop_before_pixels` is `True`, and `specific_tags` covers all necessary tags, such as `PixelSpacing`, - `ImagePositionPatient`, `ImageOrientationPatient` and all `pixel_array` related tags. 
- """ - - def __init__( - self, - channel_dim: str | int | None = None, - affine_lps_to_ras: bool = True, - swap_ij: bool = True, - prune_metadata: bool = True, - label_dict: dict | None = None, - fname_regex: str = "", - **kwargs, - ): - super().__init__() - self.kwargs = kwargs - self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim - self.affine_lps_to_ras = affine_lps_to_ras - self.swap_ij = swap_ij - self.prune_metadata = prune_metadata - self.label_dict = label_dict - self.fname_regex = fname_regex - - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified file or files format is supported by Pydicom reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - - """ - return has_pydicom - - def read(self, data: Sequence[PathLike] | PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of images - and stack them together as multi-channel data in `get_data()`. - If passing directory path instead of file path, will treat it as DICOM images series and read. - - Args: - data: file name or a list of file names to read, - kwargs: additional args for `pydicom.dcmread` API, will override `self.kwargs` for existing keys. - - Returns: - If `data` represents a filename: return a pydicom dataset object. - If `data` represents a list of filenames or a directory: return a list of pydicom dataset object. - If `data` represents a list of directories: return a list of list of pydicom dataset object. - - """ - img_ = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - - self.has_series = False - - for name in filenames: - name = f"{name}" - if Path(name).is_dir(): - # read DICOM series - if self.fname_regex is not None: - series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if re.match(self.fname_regex, slc)] - else: - series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)] - slices = [] - for slc in series_slcs: - try: - slices.append(pydicom.dcmread(fp=slc, **kwargs_)) - except pydicom.errors.InvalidDicomError as e: - warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2) - img_.append(slices if len(slices) > 1 else slices[0]) - if len(slices) > 1: - self.has_series = True - else: - ds = pydicom.dcmread(fp=name, **kwargs_) - img_.append(ds) - return img_ if len(filenames) > 1 else img_[0] - - def _combine_dicom_series(self, data: Iterable): - """ - Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new - dimension as the last dimension. - - The stack order depends on Instance Number. The metadata will be based on the - first slice's metadata, and some new items will be added: - - "spacing": the new spacing of the stacked slices. - "lastImagePositionPatient": `ImagePositionPatient` for the last slice, it will be used to achieve the affine - matrix. - "spatial_shape": the spatial shape of the stacked slices. - - Args: - data: a list of pydicom dataset objects. - Returns: - a tuple that consisted with data array and metadata. 
- """ - slices: list = [] - # for a dicom series - for slc_ds in data: - if hasattr(slc_ds, "InstanceNumber"): - slices.append(slc_ds) - else: - warnings.warn(f"slice: {slc_ds.filename} does not have InstanceNumber tag, skip it.") - slices = sorted(slices, key=lambda s: s.InstanceNumber) - - if len(slices) == 0: - raise ValueError("the input does not have valid slices.") - - first_slice = slices[0] - average_distance = 0.0 - first_array = self._get_array_data(first_slice) - shape = first_array.shape - spacing = getattr(first_slice, "PixelSpacing", [1.0, 1.0, 1.0]) - prev_pos = getattr(first_slice, "ImagePositionPatient", (0.0, 0.0, 0.0))[2] - stack_array = [first_array] - for idx in range(1, len(slices)): - slc_array = self._get_array_data(slices[idx]) - slc_shape = slc_array.shape - slc_spacing = getattr(slices[idx], "PixelSpacing", (1.0, 1.0, 1.0)) - slc_pos = getattr(slices[idx], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2] - if not np.allclose(slc_spacing, spacing): - warnings.warn(f"the list contains slices that have different spacings {spacing} and {slc_spacing}.") - if shape != slc_shape: - warnings.warn(f"the list contains slices that have different shapes {shape} and {slc_shape}.") - average_distance += abs(prev_pos - slc_pos) - prev_pos = slc_pos - stack_array.append(slc_array) - - if len(slices) > 1: - average_distance /= len(slices) - 1 - spacing.append(average_distance) - stack_array = np.stack(stack_array, axis=-1) - stack_metadata = self._get_meta_dict(first_slice) - stack_metadata["spacing"] = np.asarray(spacing) - if hasattr(slices[-1], "ImagePositionPatient"): - stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1].ImagePositionPatient) - stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape + (len(slices),) - else: - stack_array = stack_array[0] - stack_metadata = self._get_meta_dict(first_slice) - stack_metadata["spacing"] = np.asarray(spacing) - stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape - - return stack_array, stack_metadata - - def get_data(self, data) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. - For dicom series within the input, all slices will be stacked first, - When loading a list of files (dicom file, or stacked dicom series), they are stacked together at a new - dimension as the first dimension, and the metadata of the first image is used to represent the output metadata. - - To use this function, all pydicom dataset objects (if not segmentation data) should contain: - `pixel_array`, `PixelSpacing`, `ImagePositionPatient` and `ImageOrientationPatient`. - - For segmentation data, we assume that the input is not a dicom series, and the object should contain - `SegmentSequence` in order to identify it. - In addition, tags (5200, 9229) and (5200, 9230) are required to achieve - `PixelSpacing`, `ImageOrientationPatient` and `ImagePositionPatient`. - - Args: - data: a pydicom dataset object, or a list of pydicom dataset objects, or a list of list of - pydicom dataset objects. 
- - """ - - dicom_data = [] - # combine dicom series if exists - if self.has_series is True: - # a list, all objects within a list belong to one dicom series - if not isinstance(data[0], list): - dicom_data.append(self._combine_dicom_series(data)) - # a list of list, each inner list represents a dicom series - else: - for series in data: - dicom_data.append(self._combine_dicom_series(series)) - else: - # a single pydicom dataset object - if not isinstance(data, list): - data = [data] - for d in data: - if hasattr(d, "SegmentSequence"): - data_array, metadata = self._get_seg_data(d) - else: - data_array = self._get_array_data(d) - metadata = self._get_meta_dict(d) - metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape - dicom_data.append((data_array, metadata)) - - img_array: list[np.ndarray] = [] - compatible_meta: dict = {} - - for data_array, metadata in ensure_tuple(dicom_data): - img_array.append(np.ascontiguousarray(np.swapaxes(data_array, 0, 1) if self.swap_ij else data_array)) - affine = self._get_affine(metadata, self.affine_lps_to_ras) - metadata[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS - if self.swap_ij: - affine = affine @ np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) - sp_size = list(metadata[MetaKeys.SPATIAL_SHAPE]) - sp_size[0], sp_size[1] = sp_size[1], sp_size[0] - metadata[MetaKeys.SPATIAL_SHAPE] = ensure_tuple(sp_size) - metadata[MetaKeys.ORIGINAL_AFFINE] = affine - metadata[MetaKeys.AFFINE] = affine.copy() - if self.channel_dim is None: # default to "no_channel" or -1 - metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data_array.shape) == len(metadata[MetaKeys.SPATIAL_SHAPE]) else -1 - ) - else: - metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - metadata["spacing"] = affine_to_spacing( - metadata[MetaKeys.ORIGINAL_AFFINE], r=len(metadata[MetaKeys.SPATIAL_SHAPE]) - ) - - _copy_compatible_dict(metadata, compatible_meta) - - return _stack_images(img_array, compatible_meta), compatible_meta - - def _get_meta_dict(self, img) -> dict: - """ - Get all the metadata of the image and convert to dict type. - - Args: - img: a Pydicom dataset object. - - """ - - metadata = img.to_json_dict(suppress_invalid_tags=True) - - if self.prune_metadata: - prune_metadata = {} - for key in ["00200037", "00200032", "00280030", "52009229", "52009230"]: - if key in metadata.keys(): - prune_metadata[key] = metadata[key] - return prune_metadata - - # always remove Pixel Data "7FE00008" or "7FE00009" or "7FE00010" - # always remove Data Set Trailing Padding "FFFCFFFC" - for key in ["7FE00008", "7FE00009", "7FE00010", "FFFCFFFC"]: - if key in metadata.keys(): - metadata.pop(key) - - return metadata # type: ignore - - def _get_affine(self, metadata: dict, lps_to_ras: bool = True): - """ - Get or construct the affine matrix of the image, it can be used to correct - spacing, orientation or execute spatial transforms. - - Args: - metadata: metadata with dict type. - lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to True. 
- - """ - affine: np.ndarray = np.eye(4) - if not ("00200037" in metadata and "00200032" in metadata): - return affine - # "00200037" is the tag of `ImageOrientationPatient` - rx, ry, rz, cx, cy, cz = metadata["00200037"]["Value"] - # "00200032" is the tag of `ImagePositionPatient` - sx, sy, sz = metadata["00200032"]["Value"] - # "00280030" is the tag of `PixelSpacing` - spacing = metadata["00280030"]["Value"] if "00280030" in metadata else (1.0, 1.0) - dr, dc = metadata.get("spacing", spacing)[:2] - affine[0, 0] = cx * dr - affine[0, 1] = rx * dc - affine[0, 3] = sx - affine[1, 0] = cy * dr - affine[1, 1] = ry * dc - affine[1, 3] = sy - affine[2, 0] = cz * dr - affine[2, 1] = rz * dc - affine[2, 2] = 1.0 - affine[2, 3] = sz - - # 3d - if "lastImagePositionPatient" in metadata: - t1n, t2n, t3n = metadata["lastImagePositionPatient"] - n = metadata[MetaKeys.SPATIAL_SHAPE][-1] - k1, k2, k3 = (t1n - sx) / (n - 1), (t2n - sy) / (n - 1), (t3n - sz) / (n - 1) - affine[0, 2] = k1 - affine[1, 2] = k2 - affine[2, 2] = k3 - - if lps_to_ras: - affine = orientation_ras_lps(affine) - return affine - - def _get_frame_data(self, img) -> Iterator: - """ - yield frames and description from the segmentation image. - This function is adapted from Highdicom: - https://github.com/herrmannlab/highdicom/blob/v0.18.2/src/highdicom/seg/utils.py - - which has the following license... - - # ========================================================================= - # https://github.com/herrmannlab/highdicom/blob/v0.18.2/LICENSE - # - # Copyright 2020 MGH Computational Pathology - # Permission is hereby granted, free of charge, to any person obtaining a - # copy of this software and associated documentation files (the - # "Software"), to deal in the Software without restriction, including - # without limitation the rights to use, copy, modify, merge, publish, - # distribute, sublicense, and/or sell copies of the Software, and to - # permit persons to whom the Software is furnished to do so, subject to - # the following conditions: - # The above copyright notice and this permission notice shall be included - # in all copies or substantial portions of the Software. - # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - # ========================================================================= - - (https://github.com/herrmannlab/highdicom/issues/188) - - Args: - img: a Pydicom dataset object that has attribute "SegmentSequence". - - """ - - if not hasattr(img, "PerFrameFunctionalGroupsSequence"): - raise NotImplementedError( - f"To read dicom seg: {img.filename}, 'PerFrameFunctionalGroupsSequence' is required." - ) - - frame_seg_nums = [] - for f in img.PerFrameFunctionalGroupsSequence: - if not hasattr(f, "SegmentIdentificationSequence"): - raise NotImplementedError( - f"To read dicom seg: {img.filename}, 'SegmentIdentificationSequence' is required for each frame." 
- ) - frame_seg_nums.append(int(f.SegmentIdentificationSequence[0].ReferencedSegmentNumber)) - - frame_seg_nums_arr = np.array(frame_seg_nums) - - seg_descriptions = {int(f.SegmentNumber): f for f in img.SegmentSequence} - - for i in np.unique(frame_seg_nums_arr): - indices = np.where(frame_seg_nums_arr == i)[0] - yield (img.pixel_array[indices, ...], seg_descriptions[i]) - - def _get_seg_data(self, img): - """ - Get the array data and metadata of the segmentation image. - - Aegs: - img: a Pydicom dataset object that has attribute "SegmentSequence". - - """ - - metadata = self._get_meta_dict(img) - n_classes = len(img.SegmentSequence) - spatial_shape = list(img.pixel_array.shape) - spatial_shape[0] = spatial_shape[0] // n_classes - - if self.label_dict is not None: - metadata["labels"] = self.label_dict - all_segs = np.zeros([*spatial_shape, len(self.label_dict)]) - else: - metadata["labels"] = {} - all_segs = np.zeros([*spatial_shape, n_classes]) - - for i, (frames, description) in enumerate(self._get_frame_data(img)): - segment_label = getattr(description, "SegmentLabel", f"label_{i}") - class_name = getattr(description, "SegmentDescription", segment_label) - if class_name not in metadata["labels"].keys(): - metadata["labels"][class_name] = i - class_num = metadata["labels"][class_name] - all_segs[..., class_num] = frames - - all_segs = all_segs.transpose([1, 2, 0, 3]) - metadata[MetaKeys.SPATIAL_SHAPE] = all_segs.shape[:-1] - - if "52009229" in metadata.keys(): - shared_func_group_seq = metadata["52009229"]["Value"][0] - - # get `ImageOrientationPatient` - if "00209116" in shared_func_group_seq.keys(): - plane_orient_seq = shared_func_group_seq["00209116"]["Value"][0] - if "00200037" in plane_orient_seq.keys(): - metadata["00200037"] = plane_orient_seq["00200037"] - - # get `PixelSpacing` - if "00289110" in shared_func_group_seq.keys(): - pixel_measure_seq = shared_func_group_seq["00289110"]["Value"][0] - - if "00280030" in pixel_measure_seq.keys(): - pixel_spacing = pixel_measure_seq["00280030"]["Value"] - metadata["spacing"] = pixel_spacing - if "00180050" in pixel_measure_seq.keys(): - metadata["spacing"] += pixel_measure_seq["00180050"]["Value"] - - if self.prune_metadata: - metadata.pop("52009229") - - # get `ImagePositionPatient` - if "52009230" in metadata.keys(): - first_frame_func_group_seq = metadata["52009230"]["Value"][0] - if "00209113" in first_frame_func_group_seq.keys(): - plane_position_seq = first_frame_func_group_seq["00209113"]["Value"][0] - if "00200032" in plane_position_seq.keys(): - metadata["00200032"] = plane_position_seq["00200032"] - metadata["lastImagePositionPatient"] = metadata["52009230"]["Value"][-1]["00209113"]["Value"][0][ - "00200032" - ]["Value"] - if self.prune_metadata: - metadata.pop("52009230") - - return all_segs, metadata - - def _get_array_data(self, img): - """ - Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data - will be rescaled. The output data has the dtype np.float32 if the rescaling is applied. - - Args: - img: a Pydicom dataset object. 
- - """ - # process Dicom series - if not hasattr(img, "pixel_array"): - raise ValueError(f"dicom data: {img.filename} does not have pixel_array.") - data = img.pixel_array - - slope, offset = 1.0, 0.0 - rescale_flag = False - if hasattr(img, "RescaleSlope"): - slope = img.RescaleSlope - rescale_flag = True - if hasattr(img, "RescaleIntercept"): - offset = img.RescaleIntercept - rescale_flag = True - if rescale_flag: - data = data.astype(np.float32) * slope + offset - - return data - - -@require_pkg(pkg_name="nibabel") -class NibabelReader(ImageReader): - """ - Load NIfTI format images based on Nibabel library. - - Args: - as_closest_canonical: if True, load the image as closest to canonical axis format. - squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) - channel_dim: the channel dimension of the input image, default is None. - this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. - if None, `original_channel_dim` will be either `no_channel` or `-1`. - most Nifti files are usually "channel last", no need to specify this argument for them. - kwargs: additional args for `nibabel.load` API. more details about available args: - https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py - - """ - - def __init__( - self, - channel_dim: str | int | None = None, - as_closest_canonical: bool = False, - squeeze_non_spatial_dims: bool = False, - **kwargs, - ): - super().__init__() - self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim - self.as_closest_canonical = as_closest_canonical - self.squeeze_non_spatial_dims = squeeze_non_spatial_dims - self.kwargs = kwargs - - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified file or files format is supported by Nibabel reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - - """ - suffixes: Sequence[str] = ["nii", "nii.gz"] - return has_nib and is_supported_format(filename, suffixes) - - def read(self, data: Sequence[PathLike] | PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of images - and stack them together as multi-channel data in `get_data()`. - Note that the returned object is Nibabel image object or list of Nibabel image objects. - - Args: - data: file name or a list of file names to read. - kwargs: additional args for `nibabel.load` API, will override `self.kwargs` for existing keys. - More details about available args: - https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py - - """ - img_: list[Nifti1Image] = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - img = nib.load(name, **kwargs_) - img = correct_nifti_header_if_necessary(img) - img_.append(img) # type: ignore - return img_ if len(filenames) > 1 else img_[0] - - def get_data(self, img) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. - When loading a list of files, they are stacked together at a new dimension as the first dimension, - and the metadata of the first image is used to present the output metadata. 
-
-        Args:
-            img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.
-
-        """
-        img_array: list[np.ndarray] = []
-        compatible_meta: dict = {}
-
-        for i in ensure_tuple(img):
-            header = self._get_meta_dict(i)
-            header[MetaKeys.AFFINE] = self._get_affine(i)
-            header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)
-            header["as_closest_canonical"] = self.as_closest_canonical
-            if self.as_closest_canonical:
-                i = nib.as_closest_canonical(i)
-                header[MetaKeys.AFFINE] = self._get_affine(i)
-            header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
-            header[MetaKeys.SPACE] = SpaceKeys.RAS
-            data = self._get_array_data(i)
-            if self.squeeze_non_spatial_dims:
-                for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):
-                    if data.shape[d - 1] == 1:
-                        data = data.squeeze(axis=d - 1)
-            img_array.append(data)
-            if self.channel_dim is None:  # default to "no_channel" or -1
-                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
-                    float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
-                )
-            else:
-                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim
-            _copy_compatible_dict(header, compatible_meta)
-
-        return _stack_images(img_array, compatible_meta), compatible_meta
-
-    def _get_meta_dict(self, img) -> dict:
-        """
-        Get all the metadata of the image and convert to dict type.
-
-        Args:
-            img: a Nibabel image object loaded from an image file.
-
-        """
-        # swap to little endian as PyTorch doesn't support big endian
-        try:
-            header = img.header.as_byteswapped("<")
-        except ValueError:
-            header = img.header
-        return dict(header)
-
-    def _get_affine(self, img):
-        """
-        Get the affine matrix of the image, it can be used to correct
-        spacing, orientation or execute spatial transforms.
-
-        Args:
-            img: a Nibabel image object loaded from an image file.
-
-        """
-        return np.array(img.affine, copy=True)
-
-    def _get_spatial_shape(self, img):
-        """
-        Get the spatial shape of image data, it doesn't contain the channel dim.
-
-        Args:
-            img: a Nibabel image object loaded from an image file.
-
-        """
-        # swap to little endian as PyTorch doesn't support big endian
-        try:
-            header = img.header.as_byteswapped("<")
-        except ValueError:
-            header = img.header
-        dim = header.get("dim", None)
-        if dim is None:
-            dim = header.get("dims")  # mgh format?
-            dim = np.insert(dim, 0, 3)
-        ndim = dim[0]
-        size = list(dim[1:])
-        if not is_no_channel(self.channel_dim):
-            size.pop(int(self.channel_dim))  # type: ignore
-        spatial_rank = max(min(ndim, 3), 1)
-        return np.asarray(size[:spatial_rank])
-
-    def _get_array_data(self, img):
-        """
-        Get the raw array data of the image, converted to Numpy array.
-
-        Args:
-            img: a Nibabel image object loaded from an image file.
-
-        """
-        return np.asanyarray(img.dataobj, order="C")
-
-
-class NumpyReader(ImageReader):
-    """
-    Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects.
-    A typical usage is to load the `mask` data for a classification task.
-    It can load part of the npz file with specified `npz_keys`.
-
-    Args:
-        npz_keys: if loading npz file, only load the specified keys, if None, load all the items.
-            stack the loaded items together to construct a new first dimension.
-        channel_dim: if not None, explicitly specify the channel dim, otherwise, treat the array as no channel.
-        kwargs: additional args for `numpy.load` API except `allow_pickle`.
more details about available args: - https://numpy.org/doc/stable/reference/generated/numpy.load.html - - """ - - def __init__(self, npz_keys: KeysCollection | None = None, channel_dim: str | int | None = None, **kwargs): - super().__init__() - if npz_keys is not None: - npz_keys = ensure_tuple(npz_keys) - self.npz_keys = npz_keys - self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim - self.kwargs = kwargs - - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified file or files format is supported by Numpy reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - """ - suffixes: Sequence[str] = ["npz", "npy"] - return is_supported_format(filename, suffixes) - - def read(self, data: Sequence[PathLike] | PathLike, **kwargs): - """ - Read image data from specified file or files, it can read a list of data files - and stack them together as multi-channel data in `get_data()`. - Note that the returned object is Numpy array or list of Numpy arrays. - - Args: - data: file name or a list of file names to read. - kwargs: additional args for `numpy.load` API except `allow_pickle`, will override `self.kwargs` for existing keys. - More details about available args: - https://numpy.org/doc/stable/reference/generated/numpy.load.html - - """ - img_: list[Nifti1Image] = [] - - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - img = np.load(name, allow_pickle=True, **kwargs_) - if Path(name).name.endswith(".npz"): - # load expected items from NPZ file - npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys - for k in npz_keys: - img_.append(img[k]) - else: - img_.append(img) - - return img_ if len(img_) > 1 else img_[0] - - def get_data(self, img) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function returns two objects, first is numpy array of image data, second is dict of metadata. - It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict. - When loading a list of files, they are stacked together at a new dimension as the first dimension, - and the metadata of the first image is used to represent the output metadata. - - Args: - img: a Numpy array loaded from a file or a list of Numpy arrays. - - """ - img_array: list[np.ndarray] = [] - compatible_meta: dict = {} - if isinstance(img, np.ndarray): - img = (img,) - - for i in ensure_tuple(img): - header: dict[MetaKeys, Any] = {} - if isinstance(i, np.ndarray): - # if `channel_dim` is None, can not detect the channel dim, use all the dims as spatial_shape - spatial_shape = np.asarray(i.shape) - if isinstance(self.channel_dim, int): - spatial_shape = np.delete(spatial_shape, self.channel_dim) - header[MetaKeys.SPATIAL_SHAPE] = spatial_shape - header[MetaKeys.SPACE] = SpaceKeys.RAS - img_array.append(i) - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - self.channel_dim if isinstance(self.channel_dim, int) else float("nan") - ) - _copy_compatible_dict(header, compatible_meta) - - return _stack_images(img_array, compatible_meta), compatible_meta - - -@require_pkg(pkg_name="PIL") -class PILReader(ImageReader): - """ - Load common 2D image format (supports PNG, JPG, BMP) file or files from provided path. - - Args: - converter: additional function to convert the image data after `read()`. 
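The `npz_keys` behaviour documented above (the selected items are stacked along a new first dimension, as shown below) can be seen end to end in a short sketch; the key names and shapes are illustrative:

    import numpy as np
    from monai.data import NumpyReader

    # an illustrative NPZ file holding two same-shaped arrays
    np.savez("tmp.npz", image=np.zeros((4, 4)), label=np.ones((4, 4)))

    reader = NumpyReader(npz_keys=["image", "label"])
    arrays = reader.read("tmp.npz")
    data, meta = reader.get_data(arrays)
    print(data.shape, meta["spatial_shape"])  # (2, 4, 4) [4 4]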
-        for example, use `converter=lambda image: image.convert("LA")` to convert image format.
-        reverse_indexing: whether to swap axis 0 and 1 after loading the array, this is enabled by default,
-            so that output of the reader is consistent with the other readers. Set this option to ``False`` to use
-            the PIL backend's original spatial axes convention.
-        kwargs: additional args for `Image.open` API in `read()`, more details about available args:
-            https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open
-    """
-
-    def __init__(self, converter: Callable | None = None, reverse_indexing: bool = True, **kwargs):
-        super().__init__()
-        self.converter = converter
-        self.reverse_indexing = reverse_indexing
-        self.kwargs = kwargs
-
-    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
-        """
-        Verify whether the specified file or files format is supported by PIL reader.
-
-        Args:
-            filename: file name or a list of file names to read.
-                if a list of files, verify all the suffixes.
-        """
-        suffixes: Sequence[str] = ["png", "jpg", "jpeg", "bmp"]
-        return has_pil and is_supported_format(filename, suffixes)
-
-    def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs):
-        """
-        Read image data from specified file or files, it can read a list of images
-        and stack them together as multi-channel data in `get_data()`.
-        Note that the returned object is a PIL image or a list of PIL images.
-
-        Args:
-            data: file name or a list of file names to read.
-            kwargs: additional args for `Image.open` API in `read()`, will override `self.kwargs` for existing keys.
-                More details about available args:
-                https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open
-
-        """
-        img_: list[PILImage.Image] = []
-
-        filenames: Sequence[PathLike] = ensure_tuple(data)
-        kwargs_ = self.kwargs.copy()
-        kwargs_.update(kwargs)
-        for name in filenames:
-            img = PILImage.open(name, **kwargs_)
-            if callable(self.converter):
-                img = self.converter(img)
-            img_.append(img)
-
-        return img_ if len(filenames) > 1 else img_[0]
-
-    def get_data(self, img) -> tuple[np.ndarray, dict]:
-        """
-        Extract data array and metadata from loaded image and return them.
-        This function returns two objects, first is numpy array of image data, second is dict of metadata.
-        It computes `spatial_shape` and stores it in meta dict.
-        When loading a list of files, they are stacked together at a new dimension as the first dimension,
-        and the metadata of the first image is used to represent the output metadata.
-        Note that by default `self.reverse_indexing` is set to ``True``, which swaps axis 0 and 1 after loading
-        the array because the spatial axes definition in PIL is different from other common medical packages.
-
-        Args:
-            img: a PIL Image object loaded from a file or a list of PIL Image objects.
-
-        """
-        img_array: list[np.ndarray] = []
-        compatible_meta: dict = {}
-
-        for i in ensure_tuple(img):
-            header = self._get_meta_dict(i)
-            header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)
-            data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i)
-            img_array.append(data)
-            header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (
-                float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1
-            )
-            _copy_compatible_dict(header, compatible_meta)
-
-        return _stack_images(img_array, compatible_meta), compatible_meta
-
-    def _get_meta_dict(self, img) -> dict:
-        """
-        Get all the metadata of the image and convert to dict type.
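Combining the `converter` hook with the default `reverse_indexing=True` described above, a minimal usage sketch (the file name is illustrative; `"LA"` yields a two-channel luminance+alpha image):

    import numpy as np
    from PIL import Image
    from monai.data import PILReader

    Image.new("RGB", (8, 6)).save("tmp.png")  # illustrative scratch file

    reader = PILReader(converter=lambda image: image.convert("LA"))
    img = reader.read("tmp.png")
    data, meta = reader.get_data(img)
    # axes 0 and 1 are swapped, so the array agrees with spatial_shape (width, height)
    print(data.shape, meta["spatial_shape"])  # (8, 6, 2) [8 6]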
- Args: - img: a PIL Image object loaded from an image file. - - """ - return {"format": img.format, "mode": img.mode, "width": img.width, "height": img.height} - - def _get_spatial_shape(self, img): - """ - Get the spatial shape of image data, it doesn't contain the channel dim. - Args: - img: a PIL Image object loaded from an image file. - """ - return np.asarray((img.width, img.height)) - - -@dataclass -class NrrdImage: - """Class to wrap nrrd image array and metadata header""" - - array: np.ndarray - header: dict - - -@require_pkg(pkg_name="nrrd") -class NrrdReader(ImageReader): - """ - Load NRRD format images based on pynrrd library. - - Args: - channel_dim: the channel dimension of the input image, default is None. - This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field. - If None, `original_channel_dim` will be either `no_channel` or `0`. - NRRD files are usually "channel first". - dtype: dtype of the data array when loading image. - index_order: Specify whether the returned data array should be in C-order (‘C’) or Fortran-order (‘F’). - Numpy is usually in C-order, but default on the NRRD header is F - affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". Defaults to ``True``. - Set to ``True`` to be consistent with ``NibabelReader``, otherwise the affine matrix is unmodified. - - kwargs: additional args for `nrrd.read` API. more details about available args: - https://github.com/mhe/pynrrd/blob/master/nrrd/reader.py - - """ - - def __init__( - self, - channel_dim: str | int | None = None, - dtype: np.dtype | type | str | None = np.float32, - index_order: str = "F", - affine_lps_to_ras: bool = True, - **kwargs, - ): - self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim - self.dtype = dtype - self.index_order = index_order - self.affine_lps_to_ras = affine_lps_to_ras - self.kwargs = kwargs - - def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: - """ - Verify whether the specified `filename` is supported by pynrrd reader. - - Args: - filename: file name or a list of file names to read. - if a list of files, verify all the suffixes. - - """ - suffixes: Sequence[str] = ["nrrd", "seg.nrrd"] - return has_nrrd and is_supported_format(filename, suffixes) - - def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any: - """ - Read image data from specified file or files. - Note that it returns a data object or a sequence of data objects. - - Args: - data: file name or a list of file names to read. - kwargs: additional args for actual `read` API of 3rd party libs. - - """ - img_: list = [] - filenames: Sequence[PathLike] = ensure_tuple(data) - kwargs_ = self.kwargs.copy() - kwargs_.update(kwargs) - for name in filenames: - nrrd_image = NrrdImage(*nrrd.read(name, index_order=self.index_order, **kwargs_)) - img_.append(nrrd_image) - return img_ if len(filenames) > 1 else img_[0] - - def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]: - """ - Extract data array and metadata from loaded image and return them. - This function must return two objects, the first is a numpy array of image data, - the second is a dictionary of metadata. - - Args: - img: a `NrrdImage` loaded from an image file or a list of image objects. 
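The `index_order` option above maps directly onto pynrrd's own `index_order` argument; reading the same file in C order reverses the axis order relative to Fortran order. A sketch, assuming pynrrd's documented index-order behaviour (file name illustrative):

    import numpy as np
    import nrrd  # pynrrd

    nrrd.write("tmp.nrrd", np.arange(24).reshape(2, 3, 4))  # written in Fortran order by default

    data_f, header = nrrd.read("tmp.nrrd")              # index_order="F", the pynrrd default
    data_c, _ = nrrd.read("tmp.nrrd", index_order="C")  # axes come back reversed
    print(data_f.shape, data_c.shape, header["sizes"])  # (2, 3, 4) (4, 3, 2) [2 3 4]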
- - """ - img_array: list[np.ndarray] = [] - compatible_meta: dict = {} - - for i in ensure_tuple(img): - data = i.array.astype(self.dtype) - img_array.append(data) - header = dict(i.header) - if self.index_order == "C": - header = self._convert_f_to_c_order(header) - header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header) - - if self.affine_lps_to_ras: - header = self._switch_lps_ras(header) - if header.get(MetaKeys.SPACE, "left-posterior-superior") == "left-posterior-superior": - header[MetaKeys.SPACE] = SpaceKeys.LPS # assuming LPS if not specified - - header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy() - header[MetaKeys.SPATIAL_SHAPE] = header["sizes"] - [header.pop(k) for k in ("sizes", "space origin", "space directions")] # rm duplicated data in header - - if self.channel_dim is None: # default to "no_channel" or -1 - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( - float("nan") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0 - ) - else: - header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim - _copy_compatible_dict(header, compatible_meta) - - return _stack_images(img_array, compatible_meta), compatible_meta - - def _get_affine(self, header: dict) -> np.ndarray: - """ - Get the affine matrix of the image, it can be used to correct - spacing, orientation or execute spatial transforms. - - Args: - img: A `NrrdImage` loaded from image file - - """ - direction = header["space directions"] - origin = header["space origin"] - - x, y = direction.shape - affine_diam = min(x, y) + 1 - affine: np.ndarray = np.eye(affine_diam) - affine[:x, :y] = direction - affine[: (affine_diam - 1), -1] = origin # len origin is always affine_diam - 1 - return affine - - def _switch_lps_ras(self, header: dict) -> dict: - """ - For compatibility with nibabel, switch from LPS to RAS. Adapt affine matrix and - `space` argument in header accordingly. If no information of space is given in the header, - LPS is assumed and thus converted to RAS. If information about space is given, - but is not LPS, the unchanged header is returned. - - Args: - header: The image metadata as dict - - """ - if "space" not in header or header["space"] == "left-posterior-superior": - header[MetaKeys.ORIGINAL_AFFINE] = orientation_ras_lps(header[MetaKeys.ORIGINAL_AFFINE]) - header[MetaKeys.SPACE] = SpaceKeys.RAS - return header - - def _convert_f_to_c_order(self, header: dict) -> dict: - """ - All header fields of a NRRD are specified in `F` (Fortran) order, even if the image was read as C-ordered array. 
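The `_get_affine` helper above assembles a homogeneous affine from the NRRD `space directions` and `space origin` header fields. The same arithmetic in isolation, with made-up header values:

    import numpy as np

    direction = np.array([[1.5, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, 0.0, 3.0]])  # header["space directions"]
    origin = np.array([-90.0, -110.0, 15.0])                                   # header["space origin"]

    x, y = direction.shape
    affine = np.eye(min(x, y) + 1)
    affine[:x, :y] = direction
    affine[:-1, -1] = origin  # the origin occupies the translation column
    print(affine)  # 4x4: direction block top-left, origin in the last column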
- 1D arrays of header['space origin'] and header['sizes'] become inverted, e.g, [1,2,3] -> [3,2,1] - The 2D Array for header['space directions'] is transposed: [[1,0,0],[0,2,0],[0,0,3]] -> [[3,0,0],[0,2,0],[0,0,1]] - For more details refer to: https://pynrrd.readthedocs.io/en/latest/user-guide.html#index-ordering - - Args: - header: The image metadata as dict - - """ - - header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) - header["space origin"] = header["space origin"][::-1] - header["sizes"] = header["sizes"][::-1] - return header + return header \ No newline at end of file From 0699eebe43bc9e06b2fa35fb4cf93e8095e69e0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 May 2024 10:00:44 +0000 Subject: [PATCH 06/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/image_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index d11140c110..c33b62681c 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1420,4 +1420,4 @@ def _convert_f_to_c_order(self, header: dict) -> dict: header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] - return header \ No newline at end of file + return header From 274cd044a423c8d31cd28a6240547a0225c67be2 Mon Sep 17 00:00:00 2001 From: staydelight Date: Sun, 2 Jun 2024 00:05:51 +0800 Subject: [PATCH 07/52] fix-issue-7557 Signed-off-by: staydelight --- monai/data/image_reader.py | 28 +--------------------------- monai/data/image_writer.py | 26 +------------------------- monai/transforms/io/array.py | 22 ++++++++++++++++++++++ monai/transforms/io/dictionary.py | 2 ++ 4 files changed, 26 insertions(+), 52 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index c33b62681c..488d3df15e 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -12,7 +12,6 @@ from __future__ import annotations import glob -import json import os import re import warnings @@ -149,25 +148,6 @@ def _stack_images(image_list: list, meta_dict: dict): return np.stack(image_list, axis=0) -def update_json(input_file=None, output_file=None): - record_path = "img-label.json" - - if not os.path.exists(record_path) or os.stat(record_path).st_size == 0: - with open(record_path, 'w') as f: - json.dump([], f) - - with open(record_path, 'r+') as f: - records = json.load(f) - if input_file: - new_record = {"image": input_file, "label": []} - records.append(new_record) - elif output_file and records: - records[-1]["label"].append(output_file) - - f.seek(0) - json.dump(records, f, indent=4) - - @require_pkg(pkg_name="itk") class ITKReader(ImageReader): """ @@ -245,7 +225,6 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_ = [] filenames: Sequence[PathLike] = ensure_tuple(data) - update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -486,7 +465,6 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_ = [] filenames: Sequence[PathLike] = ensure_tuple(data) - update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) @@ -938,7 +916,6 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = 
ensure_tuple(data) - update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -1099,7 +1076,6 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_: list[Nifti1Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) - update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -1197,7 +1173,6 @@ def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs): img_: list[PILImage.Image] = [] filenames: Sequence[PathLike] = ensure_tuple(data) - update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -1322,7 +1297,6 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | """ img_: list = [] filenames: Sequence[PathLike] = ensure_tuple(data) - update_json(input_file=filenames) kwargs_ = self.kwargs.copy() kwargs_.update(kwargs) for name in filenames: @@ -1420,4 +1394,4 @@ def _convert_f_to_c_order(self, header: dict) -> dict: header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] - return header + return header \ No newline at end of file diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 4b7d95e71a..ba1c9dde27 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -15,8 +15,6 @@ from typing import TYPE_CHECKING, Any, cast import numpy as np -import os -import json from monai.apps.utils import get_logger from monai.config import DtypeLike, NdarrayOrTensor, PathLike @@ -198,25 +196,6 @@ def write(self, filename: PathLike, verbose: bool = True, **kwargs): if verbose: logger.info(f"writing: {filename}") - def update_json(self, input_file=None, output_file=None): - record_path = "img-label.json" - - if not os.path.exists(record_path) or os.stat(record_path).st_size == 0: - with open(record_path, 'w') as f: - json.dump([], f) - - with open(record_path, 'r+') as f: - records = json.load(f) - if input_file: - new_record = {"image": input_file, "label": []} - records.append(new_record) - elif output_file and records: - records[-1]["label"].append(output_file) - - f.seek(0) - json.dump(records, f, indent=4) - - @classmethod def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray: """ @@ -484,7 +463,6 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L809 """ super().write(filename, verbose=verbose) - super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( cast(NdarrayOrTensor, self.data_obj), channel_dim=self.channel_dim, @@ -648,7 +626,6 @@ def write(self, filename: PathLike, verbose: bool = False, **obj_kwargs): - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.save """ super().write(filename, verbose=verbose) - super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( cast(NdarrayOrTensor, self.data_obj), affine=self.affine, dtype=self.output_dtype, **obj_kwargs ) @@ -794,7 +771,6 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): - https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.save """ super().write(filename, verbose=verbose) - super().update_json(output_file=filename) self.data_obj = self.create_backend_obj( data_array=self.data_obj, 
dtype=self.output_dtype, @@ -895,4 +871,4 @@ def init(): for ext in ("nii.gz", "nii"): register_writer(ext, NibabelWriter, ITKWriter) register_writer("nrrd", ITKWriter, NibabelWriter) - register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter) + register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter) \ No newline at end of file diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 7222a26fc3..04492112b3 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -393,6 +393,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: FolderLayoutBase | None = None, savepath_in_metadict: bool = False, + mapping_log_path: Union[Path, str, None] = None ) -> None: self.folder_layout: FolderLayoutBase if folder_layout is None: @@ -438,6 +439,11 @@ def __init__( self.write_kwargs = {"verbose": print_log} self._data_index = 0 self.savepath_in_metadict = savepath_in_metadict + if mapping_log_path: + self.mapping_log_path = Path(mapping_log_path) + self.savepath_in_metadict = True + else: + self.mapping_log_path = None def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): """ @@ -506,6 +512,22 @@ def __call__( self._data_index += 1 if self.savepath_in_metadict and meta_data is not None: meta_data["saved_to"] = filename + if self.mapping_log_path and meta_data is not None: + log_data = [] + log_data.append({ + "input": meta_data.get("filename_or_obj", ()), + "output": meta_data.get("saved_to", ()) + }) + + try: + with open(self.mapping_log_path, 'r') as f: + existing_log_data = json.load(f) + except FileNotFoundError: + existing_log_data = [] + + with open(self.mapping_log_path, 'w') as f: + existing_log_data.extend(log_data) + json.dump(existing_log_data, f, indent=4) return img msg = "\n".join([f"{e}" for e in err]) raise RuntimeError( diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 4da1d422ca..966bf305ef 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -281,6 +281,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: monai.data.FolderLayoutBase | None = None, savepath_in_metadict: bool = False, + mapping_log_path: Union[Path, str, None] = None ) -> None: super().__init__(keys, allow_missing_keys) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) @@ -304,6 +305,7 @@ def __init__( output_name_formatter=output_name_formatter, folder_layout=folder_layout, savepath_in_metadict=savepath_in_metadict, + mapping_log_path= mapping_log_path, ) def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): From d4fb0b7db72d77188d06299d6ad1b5457cc5ab52 Mon Sep 17 00:00:00 2001 From: staydelight Date: Sun, 2 Jun 2024 00:11:38 +0800 Subject: [PATCH 08/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 04492112b3..10cb2bca9f 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -14,6 +14,7 @@ from __future__ import annotations +import json import inspect import logging import sys From bfb6d58f4355682c51a0ac03c767ffd2e13daf98 Mon Sep 17 00:00:00 2001 From: staydelight Date: Sun, 2 Jun 2024 00:23:23 +0800 Subject: [PATCH 09/52] Fixes #7557 Add code for generating a mapping json file. 
Signed-off-by: staydelight --- monai/transforms/io/array.py | 14 +++++++------- monai/transforms/io/dictionary.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 10cb2bca9f..cd8bf5f638 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -394,7 +394,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_log_path: Union[Path, str, None] = None + mapping_json_path: Union[Path, str, None] = None ) -> None: self.folder_layout: FolderLayoutBase if folder_layout is None: @@ -440,11 +440,11 @@ def __init__( self.write_kwargs = {"verbose": print_log} self._data_index = 0 self.savepath_in_metadict = savepath_in_metadict - if mapping_log_path: - self.mapping_log_path = Path(mapping_log_path) + if mapping_json_path: + self.mapping_json_path = Path(mapping_json_path) self.savepath_in_metadict = True else: - self.mapping_log_path = None + self.mapping_json_path = None def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): """ @@ -513,7 +513,7 @@ def __call__( self._data_index += 1 if self.savepath_in_metadict and meta_data is not None: meta_data["saved_to"] = filename - if self.mapping_log_path and meta_data is not None: + if self.mapping_json_path and meta_data is not None: log_data = [] log_data.append({ "input": meta_data.get("filename_or_obj", ()), @@ -521,12 +521,12 @@ def __call__( }) try: - with open(self.mapping_log_path, 'r') as f: + with open(self.mapping_json_path, 'r') as f: existing_log_data = json.load(f) except FileNotFoundError: existing_log_data = [] - with open(self.mapping_log_path, 'w') as f: + with open(self.mapping_json_path, 'w') as f: existing_log_data.extend(log_data) json.dump(existing_log_data, f, indent=4) return img diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 966bf305ef..927e0ad718 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -281,7 +281,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: monai.data.FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_log_path: Union[Path, str, None] = None + mapping_json_path: Union[Path, str, None] = None ) -> None: super().__init__(keys, allow_missing_keys) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) @@ -305,7 +305,7 @@ def __init__( output_name_formatter=output_name_formatter, folder_layout=folder_layout, savepath_in_metadict=savepath_in_metadict, - mapping_log_path= mapping_log_path, + mapping_json_path= mapping_json_path, ) def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): From 5ab2521a4dd310d344256548188c164a31b70b7f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 1 Jun 2024 16:32:24 +0000 Subject: [PATCH 10/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/image_reader.py | 2 +- monai/data/image_writer.py | 2 +- monai/transforms/io/array.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 488d3df15e..f5e199e2a3 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1394,4 +1394,4 
@@ def _convert_f_to_c_order(self, header: dict) -> dict: header["space directions"] = np.rot90(np.flip(header["space directions"], 0)) header["space origin"] = header["space origin"][::-1] header["sizes"] = header["sizes"][::-1] - return header \ No newline at end of file + return header diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index ba1c9dde27..b9e8b9e68e 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -871,4 +871,4 @@ def init(): for ext in ("nii.gz", "nii"): register_writer(ext, NibabelWriter, ITKWriter) register_writer("nrrd", ITKWriter, NibabelWriter) - register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter) \ No newline at end of file + register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index cd8bf5f638..961ca72111 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -443,7 +443,7 @@ def __init__( if mapping_json_path: self.mapping_json_path = Path(mapping_json_path) self.savepath_in_metadict = True - else: + else: self.mapping_json_path = None def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): @@ -519,9 +519,9 @@ def __call__( "input": meta_data.get("filename_or_obj", ()), "output": meta_data.get("saved_to", ()) }) - + try: - with open(self.mapping_json_path, 'r') as f: + with open(self.mapping_json_path) as f: existing_log_data = json.load(f) except FileNotFoundError: existing_log_data = [] From 894854deb0fa3e02f28fb472371a723e9918bc1c Mon Sep 17 00:00:00 2001 From: staydelight Date: Sun, 2 Jun 2024 00:54:54 +0800 Subject: [PATCH 11/52] Fixes #7557 Change mapping_json_path init way. Signed-off-by: staydelight --- monai/transforms/io/array.py | 8 ++++---- monai/transforms/io/dictionary.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 961ca72111..c4f396dc64 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -394,7 +394,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_json_path: Union[Path, str, None] = None + mapping_json_path: Path | str | None = None ) -> None: self.folder_layout: FolderLayoutBase if folder_layout is None: @@ -443,7 +443,7 @@ def __init__( if mapping_json_path: self.mapping_json_path = Path(mapping_json_path) self.savepath_in_metadict = True - else: + else: self.mapping_json_path = None def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): @@ -519,9 +519,9 @@ def __call__( "input": meta_data.get("filename_or_obj", ()), "output": meta_data.get("saved_to", ()) }) - + try: - with open(self.mapping_json_path) as f: + with open(self.mapping_json_path, 'r') as f: existing_log_data = json.load(f) except FileNotFoundError: existing_log_data = [] diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 927e0ad718..e3214777b9 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -281,7 +281,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: monai.data.FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_json_path: Union[Path, str, None] = None + mapping_json_path: Path | str | None = None ) -> None: 
super().__init__(keys, allow_missing_keys) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) From c37222512ee5131de0190554e684666bd12fd687 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 1 Jun 2024 16:59:13 +0000 Subject: [PATCH 12/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/io/array.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index c4f396dc64..5601b2a20a 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -443,7 +443,7 @@ def __init__( if mapping_json_path: self.mapping_json_path = Path(mapping_json_path) self.savepath_in_metadict = True - else: + else: self.mapping_json_path = None def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): @@ -519,9 +519,9 @@ def __call__( "input": meta_data.get("filename_or_obj", ()), "output": meta_data.get("saved_to", ()) }) - + try: - with open(self.mapping_json_path, 'r') as f: + with open(self.mapping_json_path) as f: existing_log_data = json.load(f) except FileNotFoundError: existing_log_data = [] From 682379b2c7a71665a65572e580c933f90f0f3ffc Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 3 Jun 2024 16:25:42 +0800 Subject: [PATCH 13/52] Fixes #7557 Fixing unsuccessful checks. Signed-off-by: staydelight --- monai/transforms/io/array.py | 13 ++++++------- monai/transforms/io/dictionary.py | 4 ++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 5601b2a20a..cdcc4da80d 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -14,8 +14,8 @@ from __future__ import annotations -import json import inspect +import json import logging import sys import traceback @@ -394,7 +394,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_json_path: Path | str | None = None + mapping_json_path: Path | str | None = None, ) -> None: self.folder_layout: FolderLayoutBase if folder_layout is None: @@ -515,10 +515,9 @@ def __call__( meta_data["saved_to"] = filename if self.mapping_json_path and meta_data is not None: log_data = [] - log_data.append({ - "input": meta_data.get("filename_or_obj", ()), - "output": meta_data.get("saved_to", ()) - }) + log_data.append( + {"input": meta_data.get("filename_or_obj", ()), "output": meta_data.get("saved_to", ())} + ) try: with open(self.mapping_json_path) as f: @@ -526,7 +525,7 @@ def __call__( except FileNotFoundError: existing_log_data = [] - with open(self.mapping_json_path, 'w') as f: + with open(self.mapping_json_path, "w") as f: existing_log_data.extend(log_data) json.dump(existing_log_data, f, indent=4) return img diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index e3214777b9..3cf46272c0 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -281,7 +281,7 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: monai.data.FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_json_path: Path | str | None = None + mapping_json_path: Path | str | None = None, ) -> None: super().__init__(keys, allow_missing_keys) self.meta_keys = 
ensure_tuple_rep(meta_keys, len(self.keys))
@@ -305,7 +305,7 @@ def __init__(
             output_name_formatter=output_name_formatter,
             folder_layout=folder_layout,
             savepath_in_metadict=savepath_in_metadict,
-            mapping_json_path= mapping_json_path,
+            mapping_json_path=mapping_json_path,
         )
 
     def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):

From 56d8df5756586782038d9d66e5b8f036310767e0 Mon Sep 17 00:00:00 2001
From: staydelight
Date: Mon, 3 Jun 2024 16:42:35 +0800
Subject: [PATCH 14/52] Fixes #7557

Fixes unsuccessful checks. (if mapping_json_path is not None)

Signed-off-by: staydelight
---
 monai/transforms/io/array.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index cdcc4da80d..b8f8cc6ee0 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -440,7 +440,7 @@ def __init__(
         self.write_kwargs = {"verbose": print_log}
         self._data_index = 0
         self.savepath_in_metadict = savepath_in_metadict
-        if mapping_json_path:
+        if mapping_json_path is not None:
             self.mapping_json_path = Path(mapping_json_path)
             self.savepath_in_metadict = True
         else:
             self.mapping_json_path = None

From 8bab11b6ea954b055b71a49c800395ab633426db Mon Sep 17 00:00:00 2001
From: staydelight
Date: Mon, 3 Jun 2024 16:53:34 +0800
Subject: [PATCH 15/52] Fixes #7557

Signed-off-by: staydelight
---
 monai/transforms/io/array.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index b8f8cc6ee0..d4af4e91a6 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -513,7 +513,7 @@ def __call__(
         self._data_index += 1
         if self.savepath_in_metadict and meta_data is not None:
             meta_data["saved_to"] = filename
-        if self.mapping_json_path and meta_data is not None:
+        if self.mapping_json_path is not None and meta_data is not None:
             log_data = []
             log_data.append(
                 {"input": meta_data.get("filename_or_obj", ()), "output": meta_data.get("saved_to", ())}
             )

From ca48feca4e3a8b95731ae32df0d227ef1fe06c11 Mon Sep 17 00:00:00 2001
From: staydelight
Date: Mon, 3 Jun 2024 17:58:29 +0800
Subject: [PATCH 16/52] fix-issue-7557

Signed-off-by: staydelight
---
 monai/transforms/io/array.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index d4af4e91a6..fee5de6270 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -440,11 +440,9 @@ def __init__(
         self.write_kwargs = {"verbose": print_log}
         self._data_index = 0
         self.savepath_in_metadict = savepath_in_metadict
+        self.mapping_json_path = Path(mapping_json_path) if mapping_json_path is not None else None
         if mapping_json_path is not None:
-            self.mapping_json_path = Path(mapping_json_path)
             self.savepath_in_metadict = True
-        else:
-            self.mapping_json_path = None
 
     def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):
         """

From 117dd7855fd2cf39854b923ed15c0873cc5ac7da Mon Sep 17 00:00:00 2001
From: staydelight
Date: Thu, 13 Jun 2024 20:48:38 +0800
Subject: [PATCH 17/52] fix-issue-7557

Signed-off-by: staydelight
---
 monai/transforms/io/array.py      | 64 +++++++++++++++++++++----------
 monai/transforms/io/dictionary.py |  4 ++--
 tests/test_mapping_json.py        | 64 +++++++++++++++++++++++++++++++
 3 files changed, 109 insertions(+), 23 deletions(-)
 create mode 100644 tests/test_mapping_json.py

diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index fee5de6270..480ab2b853 100644
--- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -15,8 +15,8 @@ from __future__ import annotations import inspect -import json import logging +import json import sys import traceback import warnings @@ -394,7 +394,6 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_json_path: Path | str | None = None, ) -> None: self.folder_layout: FolderLayoutBase if folder_layout is None: @@ -440,9 +439,6 @@ def __init__( self.write_kwargs = {"verbose": print_log} self._data_index = 0 self.savepath_in_metadict = savepath_in_metadict - self.mapping_json_path = Path(mapping_json_path) if mapping_json_path is not None else None - if mapping_json_path is not None: - self.savepath_in_metadict = True def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): """ @@ -511,21 +507,6 @@ def __call__( self._data_index += 1 if self.savepath_in_metadict and meta_data is not None: meta_data["saved_to"] = filename - if self.mapping_json_path is not None and meta_data is not None: - log_data = [] - log_data.append( - {"input": meta_data.get("filename_or_obj", ()), "output": meta_data.get("saved_to", ())} - ) - - try: - with open(self.mapping_json_path) as f: - existing_log_data = json.load(f) - except FileNotFoundError: - existing_log_data = [] - - with open(self.mapping_json_path, "w") as f: - existing_log_data.extend(log_data) - json.dump(existing_log_data, f, indent=4) return img msg = "\n".join([f"{e}" for e in err]) raise RuntimeError( @@ -534,3 +515,46 @@ def __call__( " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}" ) + +class MappingJson(Transform): + """ + Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. + + Args: + mapping_json_path (Path or str): Path to the JSON file where the mappings will be saved. + """ + + def __init__(self, mapping_json_path: Path | str = "mapping.json"): + self.mapping_json_path = Path(mapping_json_path) + + def write_json(self, input_path: str, output_path: str): + """ + Args: + input_path (str): The path of the input image file. + output_path (str): The path of the output image file. + """ + log_data = {"input": input_path, "output": output_path} + try: + with self.mapping_json_path.open("r") as f: + existing_log_data = json.load(f) + except FileNotFoundError: + existing_log_data = [] + + existing_log_data.append(log_data) + + with self.mapping_json_path.open("w") as f: + json.dump(existing_log_data, f, indent=4) + + def __call__(self, img: MetaTensor): + """ + Args: + img (MetaTensor): The input image with metadata. + """ + if "saved_to" not in img.meta: + raise KeyError("The 'saved_to' key is missing from the image metadata. 
Ensure SaveImage is configured with savepath_in_metadict=True.") + + + input_path = img.meta["filename_or_obj"] + output_path = img.meta["saved_to"] + self.write_json(input_path, output_path) + return img \ No newline at end of file diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 3cf46272c0..6cb01ccb19 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -281,7 +281,6 @@ def __init__( output_name_formatter: Callable[[dict, Transform], dict] | None = None, folder_layout: monai.data.FolderLayoutBase | None = None, savepath_in_metadict: bool = False, - mapping_json_path: Path | str | None = None, ) -> None: super().__init__(keys, allow_missing_keys) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) @@ -305,7 +304,6 @@ def __init__( output_name_formatter=output_name_formatter, folder_layout=folder_layout, savepath_in_metadict=savepath_in_metadict, - mapping_json_path=mapping_json_path, ) def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): @@ -323,4 +321,4 @@ def __call__(self, data): LoadImageD = LoadImageDict = LoadImaged -SaveImageD = SaveImageDict = SaveImaged +SaveImageD = SaveImageDict = SaveImaged \ No newline at end of file diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py new file mode 100644 index 0000000000..07254ddc11 --- /dev/null +++ b/tests/test_mapping_json.py @@ -0,0 +1,64 @@ +import unittest +import json +import numpy as np +import tempfile +import nibabel as nib +import os + +from pathlib import Path +from monai.transforms import Compose, LoadImage, SaveImage +from monai.data.meta_tensor import MetaTensor +from parameterized import parameterized +from monai.transforms.io.array import MappingJson + +class TestMappingJson(unittest.TestCase): + def setUp(self): + self.mapping_json_path = "test_mapping.json" + if Path(self.mapping_json_path).exists(): + Path(self.mapping_json_path).unlink() + + TEST_CASE_1 = [{}, ["test_image.nii.gz"], (128, 128, 128), True] + TEST_CASE_2 = [{}, ["test_image.nii.gz"], (128, 128, 128), False] + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_metadict): + test_image = np.random.rand(128, 128, 128) + + with tempfile.TemporaryDirectory() as tempdir: + for i, name in enumerate(filenames): + file_path = os.path.join(tempdir, name) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), file_path) + filenames[i] = file_path + + transforms = Compose([ + LoadImage(image_only=True, **load_params), + SaveImage(output_dir=tempdir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), + MappingJson(mapping_json_path=self.mapping_json_path) + ]) + + if savepath_in_metadict: + result = transforms(filenames[0]) + + img = result + meta = img.meta + + self.assertEqual(img.shape, expected_shape) + + self.assertTrue(Path(self.mapping_json_path).exists()) + with open(self.mapping_json_path, "r") as f: + mapping_data = json.load(f) + + expected_mapping = [{"input": meta["filename_or_obj"], "output": meta["saved_to"]}] + self.assertEqual(expected_mapping, mapping_data) + else: + with self.assertRaises(RuntimeError) as cm: + transforms(filenames[0]) + the_exception = cm.exception + self.assertIsInstance(the_exception.__cause__, KeyError) + self.assertIn( + "The 'saved_to' key is missing from the image metadata. 
Ensure SaveImage is configured with savepath_in_metadict=True.", + str(the_exception.__cause__) + ) + +if __name__ == "__main__": + unittest.main() From 3908cddc1ae78349de4c5d6ce1f4bd5907e30617 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Jun 2024 12:51:51 +0000 Subject: [PATCH 18/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/io/array.py | 6 +++--- monai/transforms/io/dictionary.py | 2 +- tests/test_mapping_json.py | 3 +-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 480ab2b853..d96196226b 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -515,7 +515,7 @@ def __call__( " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}" ) - + class MappingJson(Transform): """ Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. @@ -523,7 +523,7 @@ class MappingJson(Transform): Args: mapping_json_path (Path or str): Path to the JSON file where the mappings will be saved. """ - + def __init__(self, mapping_json_path: Path | str = "mapping.json"): self.mapping_json_path = Path(mapping_json_path) @@ -557,4 +557,4 @@ def __call__(self, img: MetaTensor): input_path = img.meta["filename_or_obj"] output_path = img.meta["saved_to"] self.write_json(input_path, output_path) - return img \ No newline at end of file + return img diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 6cb01ccb19..4da1d422ca 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -321,4 +321,4 @@ def __call__(self, data): LoadImageD = LoadImageDict = LoadImaged -SaveImageD = SaveImageDict = SaveImaged \ No newline at end of file +SaveImageD = SaveImageDict = SaveImaged diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 07254ddc11..0b4b752351 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -7,7 +7,6 @@ from pathlib import Path from monai.transforms import Compose, LoadImage, SaveImage -from monai.data.meta_tensor import MetaTensor from parameterized import parameterized from monai.transforms.io.array import MappingJson @@ -45,7 +44,7 @@ def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_ self.assertEqual(img.shape, expected_shape) self.assertTrue(Path(self.mapping_json_path).exists()) - with open(self.mapping_json_path, "r") as f: + with open(self.mapping_json_path) as f: mapping_data = json.load(f) expected_mapping = [{"input": meta["filename_or_obj"], "output": meta["saved_to"]}] From 36e5af09aa2180829b882ed10f3d8369a146caf6 Mon Sep 17 00:00:00 2001 From: staydelight Date: Thu, 13 Jun 2024 21:29:18 +0800 Subject: [PATCH 19/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 8 ++++--- tests/test_mapping_json.py | 43 ++++++++++++++++++++++++++---------- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index d96196226b..0badb23528 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -15,8 +15,8 @@ from __future__ import annotations import inspect -import logging import json +import logging import sys import traceback import 
warnings @@ -516,6 +516,7 @@ def __call__( f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}" ) + class MappingJson(Transform): """ Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. @@ -551,8 +552,9 @@ def __call__(self, img: MetaTensor): img (MetaTensor): The input image with metadata. """ if "saved_to" not in img.meta: - raise KeyError("The 'saved_to' key is missing from the image metadata. Ensure SaveImage is configured with savepath_in_metadict=True.") - + raise KeyError( + "The 'saved_to' key is missing from the image metadata. Ensure SaveImage is configured with savepath_in_metadict=True." + ) input_path = img.meta["filename_or_obj"] output_path = img.meta["saved_to"] diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 0b4b752351..99992c6d1b 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -1,15 +1,31 @@ -import unittest +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + import json -import numpy as np +import os import tempfile +import unittest +from pathlib import Path + import nibabel as nib -import os +import numpy as np +from parameterized import parameterized -from pathlib import Path +from monai.data.meta_tensor import MetaTensor from monai.transforms import Compose, LoadImage, SaveImage -from parameterized import parameterized from monai.transforms.io.array import MappingJson + class TestMappingJson(unittest.TestCase): def setUp(self): self.mapping_json_path = "test_mapping.json" @@ -29,11 +45,13 @@ def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_ nib.save(nib.Nifti1Image(test_image, np.eye(4)), file_path) filenames[i] = file_path - transforms = Compose([ - LoadImage(image_only=True, **load_params), - SaveImage(output_dir=tempdir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), - MappingJson(mapping_json_path=self.mapping_json_path) - ]) + transforms = Compose( + [ + LoadImage(image_only=True, **load_params), + SaveImage(output_dir=tempdir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), + MappingJson(mapping_json_path=self.mapping_json_path), + ] + ) if savepath_in_metadict: result = transforms(filenames[0]) @@ -44,7 +62,7 @@ def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_ self.assertEqual(img.shape, expected_shape) self.assertTrue(Path(self.mapping_json_path).exists()) - with open(self.mapping_json_path) as f: + with open(self.mapping_json_path, "r") as f: mapping_data = json.load(f) expected_mapping = [{"input": meta["filename_or_obj"], "output": meta["saved_to"]}] @@ -56,8 +74,9 @@ def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_ self.assertIsInstance(the_exception.__cause__, KeyError) self.assertIn( "The 'saved_to' key is missing from the image metadata. 
Ensure SaveImage is configured with savepath_in_metadict=True.", - str(the_exception.__cause__) + str(the_exception.__cause__), ) + if __name__ == "__main__": unittest.main() From 1a3da3874666c3b1cc6f18a6f5901b514cd6a975 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Jun 2024 13:32:08 +0000 Subject: [PATCH 20/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_json.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 99992c6d1b..0355b2c789 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -21,7 +21,6 @@ import numpy as np from parameterized import parameterized -from monai.data.meta_tensor import MetaTensor from monai.transforms import Compose, LoadImage, SaveImage from monai.transforms.io.array import MappingJson @@ -62,7 +61,7 @@ def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_ self.assertEqual(img.shape, expected_shape) self.assertTrue(Path(self.mapping_json_path).exists()) - with open(self.mapping_json_path, "r") as f: + with open(self.mapping_json_path) as f: mapping_data = json.load(f) expected_mapping = [{"input": meta["filename_or_obj"], "output": meta["saved_to"]}] From 37d19eddd2aeccd9057feb1a6fc8e2497f16ac5d Mon Sep 17 00:00:00 2001 From: staydelight Date: Thu, 13 Jun 2024 22:34:27 +0800 Subject: [PATCH 21/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 4 +--- tests/test_mapping_json.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 0badb23528..e58ceaf825 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -552,9 +552,7 @@ def __call__(self, img: MetaTensor): img (MetaTensor): The input image with metadata. """ if "saved_to" not in img.meta: - raise KeyError( - "The 'saved_to' key is missing from the image metadata. Ensure SaveImage is configured with savepath_in_metadict=True." - ) + raise KeyError("Missing 'saved_to' key in metadata. Check SaveImage savepath_in_metadict.") input_path = img.meta["filename_or_obj"] output_path = img.meta["saved_to"] diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 0355b2c789..7ab6820d4d 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -72,7 +72,7 @@ def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_ the_exception = cm.exception self.assertIsInstance(the_exception.__cause__, KeyError) self.assertIn( - "The 'saved_to' key is missing from the image metadata. Ensure SaveImage is configured with savepath_in_metadict=True.", + "Missing 'saved_to' key in metadata. 
Check SaveImage savepath_in_metadict.", str(the_exception.__cause__), ) From cff29264e8ff38157fe59b4a9b8777a1b7158cbc Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 15 Jul 2024 16:43:15 +0800 Subject: [PATCH 22/52] fix-issue-7557 Signed-off-by: staydelight --- tests/test_mapping_json.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 7ab6820d4d..92f09a6f93 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -21,6 +21,7 @@ import numpy as np from parameterized import parameterized +from monai.data import NibabelReader from monai.transforms import Compose, LoadImage, SaveImage from monai.transforms.io.array import MappingJson From 33c078b56df2bd3a82f49574c45c9e31c50610a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Jul 2024 08:48:34 +0000 Subject: [PATCH 23/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_json.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 92f09a6f93..7ab6820d4d 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -21,7 +21,6 @@ import numpy as np from parameterized import parameterized -from monai.data import NibabelReader from monai.transforms import Compose, LoadImage, SaveImage from monai.transforms.io.array import MappingJson From 40b3e2197fa8b74d772b5e65c9914b3ae3d8105f Mon Sep 17 00:00:00 2001 From: staydelight Date: Sun, 21 Jul 2024 22:53:48 +0800 Subject: [PATCH 24/52] fix-issue-7557 Signed-off-by: staydelight --- tests/test_mapping_json.py | 107 ++++++++++++++++++++----------------- 1 file changed, 57 insertions(+), 50 deletions(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 7ab6820d4d..f787db6451 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -13,69 +13,76 @@ import json import os +import shutil import tempfile import unittest from pathlib import Path -import nibabel as nib import numpy as np +import torch from parameterized import parameterized -from monai.transforms import Compose, LoadImage, SaveImage -from monai.transforms.io.array import MappingJson +from monai.tests.utils import TEST_NDARRAYS, make_nifti_image +from monai.transforms import Compose, LoadImage, MappingJson, SaveImage + +TESTS = [] +for p in TEST_NDARRAYS: + for q in TEST_NDARRAYS: + TEST_IMAGE = p(np.arange(24).reshape((2, 4, 3))) + TEST_AFFINE = q( + np.array( + [[-5.3, 0.0, 0.0, 102.01], [0.0, 0.52, 2.17, -7.50], [-0.0, 1.98, -0.26, -23.12], [0.0, 0.0, 0.0, 1.0]] + ) + ) + TESTS.append([TEST_IMAGE, TEST_AFFINE, True]) + TESTS.append([TEST_IMAGE, TEST_AFFINE, False]) class TestMappingJson(unittest.TestCase): def setUp(self): - self.mapping_json_path = "test_mapping.json" - if Path(self.mapping_json_path).exists(): - Path(self.mapping_json_path).unlink() - - TEST_CASE_1 = [{}, ["test_image.nii.gz"], (128, 128, 128), True] - TEST_CASE_2 = [{}, ["test_image.nii.gz"], (128, 128, 128), False] - - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) - def test_mapping_json(self, load_params, filenames, expected_shape, savepath_in_metadict): - test_image = np.random.rand(128, 128, 128) - - with tempfile.TemporaryDirectory() as tempdir: - for i, name in enumerate(filenames): - file_path = os.path.join(tempdir, name) - nib.save(nib.Nifti1Image(test_image, np.eye(4)), file_path) - filenames[i] = file_path - - 
transforms = Compose( - [ - LoadImage(image_only=True, **load_params), - SaveImage(output_dir=tempdir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), - MappingJson(mapping_json_path=self.mapping_json_path), - ] + self.test_dir = tempfile.mkdtemp() + self.mapping_json_path = os.path.join(self.test_dir, "mapping.json") + + def tearDown(self): + shutil.rmtree(self.test_dir, ignore_errors=True) + + @parameterized.expand(TESTS) + def test_mapping_json(self, array, affine, savepath_in_metadict): + name = "test_image" + output_ext = ".nii.gz" + test_image_name = make_nifti_image(array, affine, fname=os.path.join(self.test_dir, name)) + + input_file = os.path.join(self.test_dir, test_image_name) + output_file = os.path.join(self.test_dir, name, name + "_trans" + output_ext) + + transforms = Compose( + [ + LoadImage(reader="NibabelReader", image_only=True), + SaveImage(output_dir=self.test_dir, output_ext=output_ext, savepath_in_metadict=savepath_in_metadict), + MappingJson(mapping_json_path=self.mapping_json_path), + ] + ) + + if savepath_in_metadict: + transforms(input_file) + self.assertTrue(Path(self.mapping_json_path).exists()) + with open(self.mapping_json_path, "r") as f: + mapping_data = json.load(f) + + self.assertEqual(len(mapping_data), 1) + self.assertEqual(mapping_data[0]["input"], input_file) + self.assertEqual(mapping_data[0]["output"], output_file) + else: + with self.assertRaises(RuntimeError) as cm: + transforms(input_file) + the_exception = cm.exception + cause_exception = the_exception.__cause__ + + self.assertIsInstance(cause_exception, KeyError) + self.assertIn( + "Missing 'saved_to' key in metadata. Check SaveImage savepath_in_metadict.", str(cause_exception) ) - if savepath_in_metadict: - result = transforms(filenames[0]) - - img = result - meta = img.meta - - self.assertEqual(img.shape, expected_shape) - - self.assertTrue(Path(self.mapping_json_path).exists()) - with open(self.mapping_json_path) as f: - mapping_data = json.load(f) - - expected_mapping = [{"input": meta["filename_or_obj"], "output": meta["saved_to"]}] - self.assertEqual(expected_mapping, mapping_data) - else: - with self.assertRaises(RuntimeError) as cm: - transforms(filenames[0]) - the_exception = cm.exception - self.assertIsInstance(the_exception.__cause__, KeyError) - self.assertIn( - "Missing 'saved_to' key in metadata. 
Check SaveImage savepath_in_metadict.", - str(the_exception.__cause__), - ) - if __name__ == "__main__": unittest.main() From 36047a2e55152ef18d79c6436b4bdbc199dc7f6e Mon Sep 17 00:00:00 2001 From: staydelight Date: Sun, 21 Jul 2024 23:28:14 +0800 Subject: [PATCH 25/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ef1da2d855..bf15a74e04 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -238,7 +238,7 @@ ) from .inverse import InvertibleTransform, TraceableTransform from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict -from .io.array import SUPPORTED_READERS, LoadImage, SaveImage +from .io.array import SUPPORTED_READERS, LoadImage, MappingJson, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .lazy.array import ApplyPending from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict From b3852880990d7a46edbbb2526985c2f395023d2b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 21 Jul 2024 15:34:08 +0000 Subject: [PATCH 26/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_json.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index f787db6451..356ddd067d 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -19,7 +19,6 @@ from pathlib import Path import numpy as np -import torch from parameterized import parameterized from monai.tests.utils import TEST_NDARRAYS, make_nifti_image @@ -66,7 +65,7 @@ def test_mapping_json(self, array, affine, savepath_in_metadict): if savepath_in_metadict: transforms(input_file) self.assertTrue(Path(self.mapping_json_path).exists()) - with open(self.mapping_json_path, "r") as f: + with open(self.mapping_json_path) as f: mapping_data = json.load(f) self.assertEqual(len(mapping_data), 1) From 393744881c990b8fafa70720da41ea5a76bcc89f Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 22 Jul 2024 00:02:12 +0800 Subject: [PATCH 27/52] fix-issue-7557 Signed-off-by: staydelight --- tests/test_mapping_json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 356ddd067d..b0dfc97fe0 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -21,7 +21,7 @@ import numpy as np from parameterized import parameterized -from monai.tests.utils import TEST_NDARRAYS, make_nifti_image +from tests.utils import TEST_NDARRAYS, make_nifti_image from monai.transforms import Compose, LoadImage, MappingJson, SaveImage TESTS = [] From 44307fc84deab974302a9bebadf7dea8589e627c Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 22 Jul 2024 21:06:41 +0800 Subject: [PATCH 28/52] fix-issue-7557 Signed-off-by: staydelight --- tests/test_mapping_json.py | 45 +++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index b0dfc97fe0..0a31a10682 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -16,55 +16,50 @@ import shutil import tempfile import unittest -from pathlib import Path import numpy as np +import torch from 
parameterized import parameterized -from tests.utils import TEST_NDARRAYS, make_nifti_image +from monai.data import NibabelWriter from monai.transforms import Compose, LoadImage, MappingJson, SaveImage -TESTS = [] -for p in TEST_NDARRAYS: - for q in TEST_NDARRAYS: - TEST_IMAGE = p(np.arange(24).reshape((2, 4, 3))) - TEST_AFFINE = q( - np.array( - [[-5.3, 0.0, 0.0, 102.01], [0.0, 0.52, 2.17, -7.50], [-0.0, 1.98, -0.26, -23.12], [0.0, 0.0, 0.0, 1.0]] - ) - ) - TESTS.append([TEST_IMAGE, TEST_AFFINE, True]) - TESTS.append([TEST_IMAGE, TEST_AFFINE, False]) - class TestMappingJson(unittest.TestCase): def setUp(self): - self.test_dir = tempfile.mkdtemp() - self.mapping_json_path = os.path.join(self.test_dir, "mapping.json") + self.temp_dir = tempfile.TemporaryDirectory() + self.mapping_json_path = os.path.join(self.temp_dir.name, "mapping.json") def tearDown(self): - shutil.rmtree(self.test_dir, ignore_errors=True) + self.temp_dir.cleanup() - @parameterized.expand(TESTS) - def test_mapping_json(self, array, affine, savepath_in_metadict): - name = "test_image" + @parameterized.expand([(True,), (False,)]) + def test_mapping_json(self, savepath_in_metadict): + image_data = np.arange(48, dtype=np.uint8).reshape(1, 2, 3, 8) output_ext = ".nii.gz" - test_image_name = make_nifti_image(array, affine, fname=os.path.join(self.test_dir, name)) + name = "test_image" + + input_file = os.path.join(self.temp_dir.name, name + output_ext) + output_file = os.path.join(self.temp_dir.name, name, name + "_trans" + output_ext) - input_file = os.path.join(self.test_dir, test_image_name) - output_file = os.path.join(self.test_dir, name, name + "_trans" + output_ext) + writer = NibabelWriter() + writer.set_data_array(image_data, channel_dim=None) + writer.set_metadata({"affine": np.eye(4), "original_affine": np.eye(4)}) + writer.write(input_file) transforms = Compose( [ LoadImage(reader="NibabelReader", image_only=True), - SaveImage(output_dir=self.test_dir, output_ext=output_ext, savepath_in_metadict=savepath_in_metadict), + SaveImage( + output_dir=self.temp_dir.name, output_ext=output_ext, savepath_in_metadict=savepath_in_metadict + ), MappingJson(mapping_json_path=self.mapping_json_path), ] ) if savepath_in_metadict: transforms(input_file) - self.assertTrue(Path(self.mapping_json_path).exists()) + self.assertTrue(os.path.exists(self.mapping_json_path)) with open(self.mapping_json_path) as f: mapping_data = json.load(f) From cdf4a1bd1819be48fcb9846e6fde5fa2eb81d0a5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 13:10:45 +0000 Subject: [PATCH 29/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_json.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 0a31a10682..2b0c854d8a 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -13,12 +13,10 @@ import json import os -import shutil import tempfile import unittest import numpy as np -import torch from parameterized import parameterized from monai.data import NibabelWriter From 5dd268e3c0496a0bacdf4cecc762557b2e3d5e30 Mon Sep 17 00:00:00 2001 From: staydelight Date: Thu, 25 Jul 2024 12:35:31 +0800 Subject: [PATCH 30/52] fix-issue-7557 Signed-off-by: staydelight --- tests/test_mapping_json.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py 
index 2b0c854d8a..6f7fd740ef 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -15,15 +15,19 @@ import os import tempfile import unittest +from pathlib import Path import numpy as np from parameterized import parameterized -from monai.data import NibabelWriter from monai.transforms import Compose, LoadImage, MappingJson, SaveImage +from monai.utils import optional_import +nib, has_nib = optional_import("nibabel") -class TestMappingJson(unittest.TestCase): + +@unittest.skipUnless(has_nib, "nibabel required") +class TestMappingJsonD(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.mapping_json_path = os.path.join(self.temp_dir.name, "mapping.json") @@ -32,22 +36,19 @@ def tearDown(self): self.temp_dir.cleanup() @parameterized.expand([(True,), (False,)]) - def test_mapping_json(self, savepath_in_metadict): - image_data = np.arange(48, dtype=np.uint8).reshape(1, 2, 3, 8) + def test_mapping_jsond(self, savepath_in_metadict): + test_image = np.random.rand(128, 128, 128) output_ext = ".nii.gz" name = "test_image" input_file = os.path.join(self.temp_dir.name, name + output_ext) output_file = os.path.join(self.temp_dir.name, name, name + "_trans" + output_ext) - writer = NibabelWriter() - writer.set_data_array(image_data, channel_dim=None) - writer.set_metadata({"affine": np.eye(4), "original_affine": np.eye(4)}) - writer.write(input_file) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file) transforms = Compose( [ - LoadImage(reader="NibabelReader", image_only=True), + LoadImage(image_only=True), SaveImage( output_dir=self.temp_dir.name, output_ext=output_ext, savepath_in_metadict=savepath_in_metadict ), From 401557a36ca3e830aadbd18e86086070dff77896 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jul 2024 04:39:21 +0000 Subject: [PATCH 31/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_json.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 6f7fd740ef..898c828d32 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -15,7 +15,6 @@ import os import tempfile import unittest -from pathlib import Path import numpy as np from parameterized import parameterized From 4adb87d104c5c8ab263e22c451d7127cd00a4658 Mon Sep 17 00:00:00 2001 From: staydelight Date: Wed, 7 Aug 2024 00:34:57 +0800 Subject: [PATCH 32/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 22 +++++---- tests/test_mapping_json.py | 86 +++++++++++++++++++++++++----------- 2 files changed, 72 insertions(+), 36 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index cf71bd1926..bc8dd9e109 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -27,6 +27,7 @@ import numpy as np import torch +from filelock import FileLock from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data import image_writer @@ -520,6 +521,7 @@ def __call__( class MappingJson(Transform): """ Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. + This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment. Args: mapping_json_path (Path or str): Path to the JSON file where the mappings will be saved. 
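
The FileLock sentence in the docstring above is the heart of this patch: the lock serializes the whole
read-modify-write cycle on the JSON file, so concurrent workers can never interleave between loading the
existing records and writing them back. A minimal standalone sketch of that cycle (assuming the third-party
filelock package is installed; append_record and the file names are illustrative, not part of the patch):

    import json
    from pathlib import Path

    from filelock import FileLock

    def append_record(mapping_path: str, record: dict) -> None:
        # One sidecar lock file per JSON file, e.g. "mapping.json.lock".
        with FileLock(mapping_path + ".lock"):  # blocks until no other process holds it
            path = Path(mapping_path)
            try:
                records = json.loads(path.read_text())
            except (FileNotFoundError, json.JSONDecodeError):
                records = []  # first writer, or an empty/partial file
            records.append(record)
            path.write_text(json.dumps(records, indent=4))

Each MappingJson call then reduces to one such locked append, which is what lets the multiprocess test
below assert an exact record count.
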
@@ -527,6 +529,7 @@ class MappingJson(Transform): def __init__(self, mapping_json_path: Path | str = "mapping.json"): self.mapping_json_path = Path(mapping_json_path) + self.lock = FileLock(str(self.mapping_json_path) + ".lock") def write_json(self, input_path: str, output_path: str): """ @@ -535,16 +538,18 @@ def write_json(self, input_path: str, output_path: str): output_path (str): The path of the output image file. """ log_data = {"input": input_path, "output": output_path} - try: - with self.mapping_json_path.open("r") as f: - existing_log_data = json.load(f) - except FileNotFoundError: - existing_log_data = [] - existing_log_data.append(log_data) + with self.lock: + try: + with self.mapping_json_path.open("r") as f: + existing_log_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + existing_log_data = [] + + existing_log_data.append(log_data) - with self.mapping_json_path.open("w") as f: - json.dump(existing_log_data, f, indent=4) + with self.mapping_json_path.open("w") as f: + json.dump(existing_log_data, f, indent=4) def __call__(self, img: MetaTensor): """ @@ -553,7 +558,6 @@ def __call__(self, img: MetaTensor): """ if "saved_to" not in img.meta: raise KeyError("Missing 'saved_to' key in metadata. Check SaveImage savepath_in_metadict.") - input_path = img.meta["filename_or_obj"] output_path = img.meta["saved_to"] self.write_json(input_path, output_path) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 898c828d32..9deb79ce31 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -12,8 +12,12 @@ from __future__ import annotations import json +import multiprocessing import os +import random +import shutil import tempfile +import time import unittest import numpy as np @@ -25,56 +29,84 @@ nib, has_nib = optional_import("nibabel") +def create_input_file(temp_dir, name): + test_image = np.random.rand(128, 128, 128) + output_ext = ".nii.gz" + input_file = os.path.join(temp_dir, name + output_ext) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file) + return input_file + + +def create_transform(temp_dir, mapping_json_path, savepath_in_metadict=True): + return Compose( + [ + LoadImage(image_only=True), + SaveImage(output_dir=temp_dir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), + MappingJson(mapping_json_path=mapping_json_path), + ] + ) + + +def process_image(args): + temp_dir, mapping_json_path, i = args + time.sleep(random.uniform(0, 0.1)) + input_file = create_input_file(temp_dir, f"test_image_{i}") + transform = create_transform(temp_dir, mapping_json_path) + transform(input_file) + time.sleep(random.uniform(0, 0.1)) + + @unittest.skipUnless(has_nib, "nibabel required") -class TestMappingJsonD(unittest.TestCase): +class TestMappingJson(unittest.TestCase): def setUp(self): - self.temp_dir = tempfile.TemporaryDirectory() - self.mapping_json_path = os.path.join(self.temp_dir.name, "mapping.json") + self.temp_dir = tempfile.mkdtemp() + self.mapping_json_path = os.path.join(self.temp_dir, "mapping.json") def tearDown(self): - self.temp_dir.cleanup() + shutil.rmtree(self.temp_dir) @parameterized.expand([(True,), (False,)]) - def test_mapping_jsond(self, savepath_in_metadict): - test_image = np.random.rand(128, 128, 128) - output_ext = ".nii.gz" + def test_mapping_json(self, savepath_in_metadict): name = "test_image" + input_file = create_input_file(self.temp_dir, name) + output_file = os.path.join(self.temp_dir, name, name + "_trans.nii.gz") - input_file = os.path.join(self.temp_dir.name, 
name + output_ext) - output_file = os.path.join(self.temp_dir.name, name, name + "_trans" + output_ext) - - nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file) - - transforms = Compose( - [ - LoadImage(image_only=True), - SaveImage( - output_dir=self.temp_dir.name, output_ext=output_ext, savepath_in_metadict=savepath_in_metadict - ), - MappingJson(mapping_json_path=self.mapping_json_path), - ] - ) + transform = create_transform(self.temp_dir, self.mapping_json_path, savepath_in_metadict) if savepath_in_metadict: - transforms(input_file) + transform(input_file) self.assertTrue(os.path.exists(self.mapping_json_path)) with open(self.mapping_json_path) as f: mapping_data = json.load(f) - self.assertEqual(len(mapping_data), 1) self.assertEqual(mapping_data[0]["input"], input_file) self.assertEqual(mapping_data[0]["output"], output_file) else: with self.assertRaises(RuntimeError) as cm: - transforms(input_file) - the_exception = cm.exception - cause_exception = the_exception.__cause__ - + transform(input_file) + cause_exception = cm.exception.__cause__ self.assertIsInstance(cause_exception, KeyError) self.assertIn( "Missing 'saved_to' key in metadata. Check SaveImage savepath_in_metadict.", str(cause_exception) ) + def test_multiprocess_mapping_json(self): + num_processes, num_images = 16, 1000 + + with multiprocessing.Pool(processes=num_processes) as pool: + args = [(self.temp_dir, self.mapping_json_path, i) for i in range(num_images)] + pool.map(process_image, args) + + with open(self.mapping_json_path) as f: + mapping_data = json.load(f) + + self.assertEqual(len(mapping_data), num_images, f"Expected {num_images} entries, but got {len(mapping_data)}") + unique_entries = set(tuple(sorted(entry.items())) for entry in mapping_data) + self.assertEqual(len(mapping_data), len(unique_entries), "Duplicate entries exist") + for entry in mapping_data: + self.assertIn("input", entry, "Entry missing 'input' key") + self.assertIn("output", entry, "Entry missing 'output' key") + if __name__ == "__main__": unittest.main() From 0c14f4fc2ac6b2852e60e33776a5d9bb48791dee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Aug 2024 16:36:42 +0000 Subject: [PATCH 33/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 9deb79ce31..2e57c0b360 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -101,7 +101,7 @@ def test_multiprocess_mapping_json(self): mapping_data = json.load(f) self.assertEqual(len(mapping_data), num_images, f"Expected {num_images} entries, but got {len(mapping_data)}") - unique_entries = set(tuple(sorted(entry.items())) for entry in mapping_data) + unique_entries = {tuple(sorted(entry.items())) for entry in mapping_data} self.assertEqual(len(mapping_data), len(unique_entries), "Duplicate entries exist") for entry in mapping_data: self.assertIn("input", entry, "Entry missing 'input' key") From 8fafc050d59253c2dd0594360239b5e5dc9a654c Mon Sep 17 00:00:00 2001 From: staydelight Date: Wed, 7 Aug 2024 01:51:54 +0800 Subject: [PATCH 34/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 2 +- tests/test_mapping_json.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 
bc8dd9e109..cc45e791fc 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -27,7 +27,6 @@ import numpy as np import torch -from filelock import FileLock from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data import image_writer @@ -52,6 +51,7 @@ nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") nrrd, _ = optional_import("nrrd") +FileLock, _ = optional_import("filelock") __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"] diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index 2e57c0b360..aaaccf1a0e 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -91,7 +91,7 @@ def test_mapping_json(self, savepath_in_metadict): ) def test_multiprocess_mapping_json(self): - num_processes, num_images = 16, 1000 + num_processes, num_images = 8, 300 with multiprocessing.Pool(processes=num_processes) as pool: args = [(self.temp_dir, self.mapping_json_path, i) for i in range(num_images)] From b238987d28443d367f8e0a524e961ba6e0efcc38 Mon Sep 17 00:00:00 2001 From: staydelight Date: Wed, 7 Aug 2024 02:35:10 +0800 Subject: [PATCH 35/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index cc45e791fc..e6a92c522d 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -51,7 +51,7 @@ nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") nrrd, _ = optional_import("nrrd") -FileLock, _ = optional_import("filelock") +FileLock, _ = optional_import("filelock", name="FileLock") __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"] From 5c7599057e4bd3cc46374dc3fcee645100fb7f70 Mon Sep 17 00:00:00 2001 From: staydelight Date: Wed, 7 Aug 2024 13:30:19 +0800 Subject: [PATCH 36/52] fix-issue-7557 Signed-off-by: staydelight --- tests/test_mapping_json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mapping_json.py b/tests/test_mapping_json.py index aaaccf1a0e..212bd80903 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_json.py @@ -91,7 +91,7 @@ def test_mapping_json(self, savepath_in_metadict): ) def test_multiprocess_mapping_json(self): - num_processes, num_images = 8, 300 + num_processes, num_images = 3, 50 with multiprocessing.Pool(processes=num_processes) as pool: args = [(self.temp_dir, self.mapping_json_path, i) for i in range(num_images)] From f7deb8635ee43c5d759845f4a4f4ef2fbaecf302 Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 19 Aug 2024 15:39:48 +0800 Subject: [PATCH 37/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/__init__.py | 2 +- monai/transforms/io/array.py | 48 +++++------ monai/utils/enums.py | 1 + ...t_mapping_json.py => test_mapping_file.py} | 81 ++++++++++--------- 4 files changed, 66 insertions(+), 66 deletions(-) rename tests/{test_mapping_json.py => test_mapping_file.py} (51%) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 63f44d04c1..bcafcb753d 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -238,7 +238,7 @@ ) from .inverse import InvertibleTransform, TraceableTransform from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict -from .io.array import SUPPORTED_READERS, LoadImage, MappingJson, SaveImage +from .io.array import SUPPORTED_READERS, LoadImage, WriteFileMapping, SaveImage from .io.dictionary 
import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .lazy.array import ApplyPending from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index e6a92c522d..bfd6f32c8f 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -46,6 +46,7 @@ from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSamplePadMode from monai.utils import ImageMetaKey as Key +from monai.utils import MetaKeys from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import nib, _ = optional_import("nibabel") @@ -507,7 +508,7 @@ def __call__( else: self._data_index += 1 if self.savepath_in_metadict and meta_data is not None: - meta_data["saved_to"] = filename + meta_data[MetaKeys.SAVED_TO] = filename return img msg = "\n".join([f"{e}" for e in err]) raise RuntimeError( @@ -518,47 +519,40 @@ def __call__( ) -class MappingJson(Transform): +class WriteFileMapping(Transform): """ Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment. - + Args: - mapping_json_path (Path or str): Path to the JSON file where the mappings will be saved. + mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved. """ + def __init__(self, mapping_file_path: Path | str = "mapping.json"): + self.mapping_file_path = Path(mapping_file_path) + self.lock = FileLock(str(self.mapping_file_path) + ".lock") - def __init__(self, mapping_json_path: Path | str = "mapping.json"): - self.mapping_json_path = Path(mapping_json_path) - self.lock = FileLock(str(self.mapping_json_path) + ".lock") - - def write_json(self, input_path: str, output_path: str): + def __call__(self, img: MetaTensor): """ Args: - input_path (str): The path of the input image file. - output_path (str): The path of the output image file. + img (MetaTensor): The input image with metadata. """ + if MetaKeys.SAVED_TO not in img.meta: + raise KeyError("Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.") + + input_path = img.meta[Key.FILENAME_OR_OBJ] + output_path = img.meta[MetaKeys.SAVED_TO] log_data = {"input": input_path, "output": output_path} - + with self.lock: try: - with self.mapping_json_path.open("r") as f: + with self.mapping_file_path.open("r") as f: existing_log_data = json.load(f) except (FileNotFoundError, json.JSONDecodeError): existing_log_data = [] - + existing_log_data.append(log_data) - - with self.mapping_json_path.open("w") as f: + + with self.mapping_file_path.open("w") as f: json.dump(existing_log_data, f, indent=4) - - def __call__(self, img: MetaTensor): - """ - Args: - img (MetaTensor): The input image with metadata. - """ - if "saved_to" not in img.meta: - raise KeyError("Missing 'saved_to' key in metadata. 
Check SaveImage savepath_in_metadict.") - input_path = img.meta["filename_or_obj"] - output_path = img.meta["saved_to"] - self.write_json(input_path, output_path) + return img diff --git a/monai/utils/enums.py b/monai/utils/enums.py index b786e92151..eba1be18ed 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -543,6 +543,7 @@ class MetaKeys(StrEnum): SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension SPACE = "space" # possible values of space type are defined in `SpaceKeys` ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan") + SAVED_TO = "saved_to" class ColorOrder(StrEnum): diff --git a/tests/test_mapping_json.py b/tests/test_mapping_file.py similarity index 51% rename from tests/test_mapping_json.py rename to tests/test_mapping_file.py index 212bd80903..ca71dcb72a 100644 --- a/tests/test_mapping_json.py +++ b/tests/test_mapping_file.py @@ -12,7 +12,6 @@ from __future__ import annotations import json -import multiprocessing import os import random import shutil @@ -23,7 +22,8 @@ import numpy as np from parameterized import parameterized -from monai.transforms import Compose, LoadImage, MappingJson, SaveImage +from monai.data import Dataset, DataLoader +from monai.transforms import Compose, LoadImage, WriteFileMapping, SaveImage from monai.utils import optional_import nib, has_nib = optional_import("nibabel") @@ -37,46 +37,37 @@ def create_input_file(temp_dir, name): return input_file -def create_transform(temp_dir, mapping_json_path, savepath_in_metadict=True): +def create_transform(temp_dir, mapping_file_path, savepath_in_metadict=True): return Compose( [ LoadImage(image_only=True), SaveImage(output_dir=temp_dir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict), - MappingJson(mapping_json_path=mapping_json_path), + WriteFileMapping(mapping_file_path=mapping_file_path), ] ) -def process_image(args): - temp_dir, mapping_json_path, i = args - time.sleep(random.uniform(0, 0.1)) - input_file = create_input_file(temp_dir, f"test_image_{i}") - transform = create_transform(temp_dir, mapping_json_path) - transform(input_file) - time.sleep(random.uniform(0, 0.1)) - - @unittest.skipUnless(has_nib, "nibabel required") -class TestMappingJson(unittest.TestCase): +class TestWriteFileMapping(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.mkdtemp() - self.mapping_json_path = os.path.join(self.temp_dir, "mapping.json") def tearDown(self): shutil.rmtree(self.temp_dir) @parameterized.expand([(True,), (False,)]) - def test_mapping_json(self, savepath_in_metadict): + def test_mapping_file(self, savepath_in_metadict): + mapping_file_path = os.path.join(self.temp_dir, "mapping.json") name = "test_image" input_file = create_input_file(self.temp_dir, name) output_file = os.path.join(self.temp_dir, name, name + "_trans.nii.gz") - transform = create_transform(self.temp_dir, self.mapping_json_path, savepath_in_metadict) + transform = create_transform(self.temp_dir, mapping_file_path, savepath_in_metadict) if savepath_in_metadict: transform(input_file) - self.assertTrue(os.path.exists(self.mapping_json_path)) - with open(self.mapping_json_path) as f: + self.assertTrue(os.path.exists(mapping_file_path)) + with open(mapping_file_path) as f: mapping_data = json.load(f) self.assertEqual(len(mapping_data), 1) self.assertEqual(mapping_data[0]["input"], input_file) @@ -87,26 +78,40 @@ def test_mapping_json(self, savepath_in_metadict): cause_exception = cm.exception.__cause__ self.assertIsInstance(cause_exception, 
KeyError) self.assertIn( - "Missing 'saved_to' key in metadata. Check SaveImage savepath_in_metadict.", str(cause_exception) + "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.", str(cause_exception) ) - def test_multiprocess_mapping_json(self): - num_processes, num_images = 3, 50 - - with multiprocessing.Pool(processes=num_processes) as pool: - args = [(self.temp_dir, self.mapping_json_path, i) for i in range(num_images)] - pool.map(process_image, args) - - with open(self.mapping_json_path) as f: - mapping_data = json.load(f) - - self.assertEqual(len(mapping_data), num_images, f"Expected {num_images} entries, but got {len(mapping_data)}") - unique_entries = {tuple(sorted(entry.items())) for entry in mapping_data} - self.assertEqual(len(mapping_data), len(unique_entries), "Duplicate entries exist") - for entry in mapping_data: - self.assertIn("input", entry, "Entry missing 'input' key") - self.assertIn("output", entry, "Entry missing 'output' key") - + def test_multiprocess_mapping_file(self): + num_images = 50 + + single_mapping_file = os.path.join(self.temp_dir, "single_mapping.json") + multi_mapping_file = os.path.join(self.temp_dir, "multi_mapping.json") + + data = [create_input_file(self.temp_dir, f"test_image_{i}") for i in range(num_images)] + + # single process + single_transform = create_transform(self.temp_dir, single_mapping_file) + single_dataset = Dataset(data=data, transform=single_transform) + single_loader = DataLoader(single_dataset, batch_size=1, num_workers=0, shuffle=True) + for _ in single_loader: + pass + + # multiple processes + multi_transform = create_transform(self.temp_dir, multi_mapping_file) + multi_dataset = Dataset(data=data, transform=multi_transform) + multi_loader = DataLoader(multi_dataset, batch_size=2, num_workers=2, shuffle=True) + for _ in multi_loader: + pass + + with open(single_mapping_file) as f: + single_mapping_data = json.load(f) + with open(multi_mapping_file) as f: + multi_mapping_data = json.load(f) + + single_set = set((entry['input'], entry['output']) for entry in single_mapping_data) + multi_set = set((entry['input'], entry['output']) for entry in multi_mapping_data) + + self.assertEqual(single_set, multi_set) if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file From 3607b1829174e31c633e30501339680ad73e7b5a Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 19 Aug 2024 15:57:56 +0800 Subject: [PATCH 38/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/__init__.py | 2 +- monai/transforms/io/array.py | 27 ++++++++++++++++++--------- tests/test_mapping_file.py | 28 +++++++++++++++------------- 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index bcafcb753d..69d4426c57 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -238,7 +238,7 @@ ) from .inverse import InvertibleTransform, TraceableTransform from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict -from .io.array import SUPPORTED_READERS, LoadImage, WriteFileMapping, SaveImage +from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .lazy.array import ApplyPending from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py 
index bfd6f32c8f..d63366f1a2 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -46,8 +46,14 @@ from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSamplePadMode from monai.utils import ImageMetaKey as Key -from monai.utils import MetaKeys -from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import +from monai.utils import ( + MetaKeys, + OptionalImportError, + convert_to_dst_type, + ensure_tuple, + look_up_option, + optional_import, +) nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -523,10 +529,11 @@ class WriteFileMapping(Transform): """ Writes a JSON file that logs the mapping between input image paths and their corresponding output paths. This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment. - + Args: mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved. """ + def __init__(self, mapping_file_path: Path | str = "mapping.json"): self.mapping_file_path = Path(mapping_file_path) self.lock = FileLock(str(self.mapping_file_path) + ".lock") @@ -537,22 +544,24 @@ def __call__(self, img: MetaTensor): img (MetaTensor): The input image with metadata. """ if MetaKeys.SAVED_TO not in img.meta: - raise KeyError("Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.") - + raise KeyError( + "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True." + ) + input_path = img.meta[Key.FILENAME_OR_OBJ] output_path = img.meta[MetaKeys.SAVED_TO] log_data = {"input": input_path, "output": output_path} - + with self.lock: try: with self.mapping_file_path.open("r") as f: existing_log_data = json.load(f) except (FileNotFoundError, json.JSONDecodeError): existing_log_data = [] - + existing_log_data.append(log_data) - + with self.mapping_file_path.open("w") as f: json.dump(existing_log_data, f, indent=4) - + return img diff --git a/tests/test_mapping_file.py b/tests/test_mapping_file.py index ca71dcb72a..ffec460a61 100644 --- a/tests/test_mapping_file.py +++ b/tests/test_mapping_file.py @@ -22,8 +22,8 @@ import numpy as np from parameterized import parameterized -from monai.data import Dataset, DataLoader -from monai.transforms import Compose, LoadImage, WriteFileMapping, SaveImage +from monai.data import DataLoader, Dataset +from monai.transforms import Compose, LoadImage, SaveImage, WriteFileMapping from monai.utils import optional_import nib, has_nib = optional_import("nibabel") @@ -78,40 +78,42 @@ def test_mapping_file(self, savepath_in_metadict): cause_exception = cm.exception.__cause__ self.assertIsInstance(cause_exception, KeyError) self.assertIn( - "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.", str(cause_exception) + "Missing 'saved_to' key in metadata. 
Check SaveImage argument 'savepath_in_metadict' is True.", + str(cause_exception), ) def test_multiprocess_mapping_file(self): num_images = 50 - + single_mapping_file = os.path.join(self.temp_dir, "single_mapping.json") multi_mapping_file = os.path.join(self.temp_dir, "multi_mapping.json") - + data = [create_input_file(self.temp_dir, f"test_image_{i}") for i in range(num_images)] - + # single process single_transform = create_transform(self.temp_dir, single_mapping_file) single_dataset = Dataset(data=data, transform=single_transform) single_loader = DataLoader(single_dataset, batch_size=1, num_workers=0, shuffle=True) for _ in single_loader: pass - + # multiple processes multi_transform = create_transform(self.temp_dir, multi_mapping_file) multi_dataset = Dataset(data=data, transform=multi_transform) multi_loader = DataLoader(multi_dataset, batch_size=2, num_workers=2, shuffle=True) for _ in multi_loader: pass - + with open(single_mapping_file) as f: single_mapping_data = json.load(f) with open(multi_mapping_file) as f: multi_mapping_data = json.load(f) - - single_set = set((entry['input'], entry['output']) for entry in single_mapping_data) - multi_set = set((entry['input'], entry['output']) for entry in multi_mapping_data) - + + single_set = set((entry["input"], entry["output"]) for entry in single_mapping_data) + multi_set = set((entry["input"], entry["output"]) for entry in multi_mapping_data) + self.assertEqual(single_set, multi_set) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 130eaa1153298c543bcd153c4315a38a019be8ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 08:00:26 +0000 Subject: [PATCH 39/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_mapping_file.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_mapping_file.py b/tests/test_mapping_file.py index ffec460a61..8a9d03565c 100644 --- a/tests/test_mapping_file.py +++ b/tests/test_mapping_file.py @@ -13,10 +13,8 @@ import json import os -import random import shutil import tempfile -import time import unittest import numpy as np @@ -109,8 +107,8 @@ def test_multiprocess_mapping_file(self): with open(multi_mapping_file) as f: multi_mapping_data = json.load(f) - single_set = set((entry["input"], entry["output"]) for entry in single_mapping_data) - multi_set = set((entry["input"], entry["output"]) for entry in multi_mapping_data) + single_set = {(entry["input"], entry["output"]) for entry in single_mapping_data} + multi_set = {(entry["input"], entry["output"]) for entry in multi_mapping_data} self.assertEqual(single_set, multi_set) From b1475be40f954d6c2314cd60d47bc2611c0a6942 Mon Sep 17 00:00:00 2001 From: staydelight Date: Mon, 19 Aug 2024 17:43:04 +0800 Subject: [PATCH 40/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 5 +++-- tests/test_mapping_file.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index d63366f1a2..42b90523b3 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -536,7 +536,6 @@ class WriteFileMapping(Transform): def __init__(self, mapping_file_path: Path | str = "mapping.json"): self.mapping_file_path = Path(mapping_file_path) - self.lock = FileLock(str(self.mapping_file_path) + ".lock") def __call__(self, img: MetaTensor): """ @@ 
-552,7 +551,9 @@ def __call__(self, img: MetaTensor): output_path = img.meta[MetaKeys.SAVED_TO] log_data = {"input": input_path, "output": output_path} - with self.lock: + lock = FileLock(str(self.mapping_file_path) + ".lock") + + with lock: try: with self.mapping_file_path.open("r") as f: existing_log_data = json.load(f) diff --git a/tests/test_mapping_file.py b/tests/test_mapping_file.py index 8a9d03565c..97fa4312ed 100644 --- a/tests/test_mapping_file.py +++ b/tests/test_mapping_file.py @@ -98,7 +98,7 @@ def test_multiprocess_mapping_file(self): # multiple processes multi_transform = create_transform(self.temp_dir, multi_mapping_file) multi_dataset = Dataset(data=data, transform=multi_transform) - multi_loader = DataLoader(multi_dataset, batch_size=2, num_workers=2, shuffle=True) + multi_loader = DataLoader(multi_dataset, batch_size=4, num_workers=3, shuffle=True) for _ in multi_loader: pass From ca1515622fde962885bf8ab15f10e865dfd70ab8 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 27 Aug 2024 15:44:52 +0800 Subject: [PATCH 41/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/__init__.py | 10 ++- monai/transforms/io/dictionary.py | 31 +++++++- tests/test_mapping_filed.py | 118 ++++++++++++++++++++++++++++++ 3 files changed, 156 insertions(+), 3 deletions(-) create mode 100644 tests/test_mapping_filed.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 69d4426c57..cf6f35dfe0 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -239,7 +239,15 @@ from .inverse import InvertibleTransform, TraceableTransform from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping -from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict +from .io.dictionary import ( + LoadImaged, + LoadImageD, + LoadImageDict, + SaveImaged, + SaveImageD, + SaveImageDict, + WriteFileMappingd, +) from .lazy.array import ApplyPending from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict from .lazy.functional import apply_pending diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 4da1d422ca..eb5178b3ba 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -17,16 +17,18 @@ from __future__ import annotations +from collections.abc import Hashable, Mapping, Sequence from pathlib import Path from typing import Callable import numpy as np +from filelock import FileLock import monai -from monai.config import DtypeLike, KeysCollection +from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor, PathLike from monai.data import image_writer from monai.data.image_reader import ImageReader -from monai.transforms.io.array import LoadImage, SaveImage +from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping from monai.transforms.transform import MapTransform, Transform from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep from monai.utils.enums import PostFix @@ -320,5 +322,30 @@ def __call__(self, data): return d +class WriteFileMappingd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + mapping_file_path: Path to the JSON file where the mappings will be saved. 
+ Defaults to "mapping.json". + allow_missing_keys: don't raise exception if key is missing. + """ + + def __init__( + self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mapping = WriteFileMapping(mapping_file_path) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.mapping(d[key]) + return d + + LoadImageD = LoadImageDict = LoadImaged SaveImageD = SaveImageDict = SaveImaged diff --git a/tests/test_mapping_filed.py b/tests/test_mapping_filed.py new file mode 100644 index 0000000000..a9b4409d7c --- /dev/null +++ b/tests/test_mapping_filed.py @@ -0,0 +1,118 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import os +import shutil +import tempfile +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import DataLoader, Dataset, decollate_batch +from monai.inferers import sliding_window_inference +from monai.networks.nets import UNet +from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, SaveImaged, WriteFileMappingd +from monai.utils import optional_import + +nib, has_nib = optional_import("nibabel") + + +def create_input_file(temp_dir, name): + test_image = np.random.rand(128, 128, 128) + input_file = os.path.join(temp_dir, name + ".nii.gz") + nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file) + return input_file + + +TEST_CASE_1 = [["seg"], ["seg"]] +TEST_CASE_2 = [["seg"], ["image"]] +TEST_CASE_3 = [["image"], ["seg"]] +TEST_CASE_4 = [["image", "seg"], ["seg"]] +TEST_CASE_5 = [["seg"], ["image", "seg"]] + + +@unittest.skipUnless(has_nib, "nibabel required") +class TestWriteFileMappingd(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.output_dir = os.path.join(self.temp_dir, "output") + os.makedirs(self.output_dir) + self.mapping_file_path = os.path.join(self.temp_dir, "mapping.json") + + def tearDown(self): + shutil.rmtree(self.temp_dir) + if os.path.exists(self.mapping_file_path): + os.remove(self.mapping_file_path) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_mapping_filed(self, save_keys, write_keys): + + name = "test_image" + input_file = create_input_file(self.temp_dir, name) + output_file = os.path.join(self.output_dir, name, name + "_seg.nii.gz") + data = [{"image": input_file}] + + test_transforms = Compose([LoadImaged(keys=["image"]), EnsureChannelFirstd(keys=["image"])]) + + post_transforms = Compose( + [ + SaveImaged( + keys=save_keys, + meta_keys="image_meta_dict", + output_dir=self.output_dir, + output_postfix="seg", + savepath_in_metadict=True, + ), + WriteFileMappingd(keys=write_keys, mapping_file_path=self.mapping_file_path), + ] + ) + + dataset = Dataset(data=data, 
transform=test_transforms) + dataloader = DataLoader(dataset, batch_size=1) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32), strides=(2,)).to(device) + model.eval() + + try: + with torch.no_grad(): + for batch_data in dataloader: + test_inputs = batch_data["image"].to(device) + roi_size = (64, 64, 64) + sw_batch_size = 2 + batch_data["seg"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model) + batch_data = [post_transforms(i) for i in decollate_batch(batch_data)] + + self.assertTrue(os.path.exists(self.mapping_file_path)) + + with open(self.mapping_file_path, "r") as f: + mapping_data = json.load(f) + + self.assertEqual(len(mapping_data), len(write_keys)) + for entry in mapping_data: + self.assertEqual(entry["input"], input_file) + self.assertEqual(entry["output"], output_file) + + except RuntimeError as cm: + cause_exception = cm.__cause__ + self.assertIsInstance(cause_exception, KeyError) + self.assertIn( + "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.", + str(cause_exception), + ) + + +if __name__ == "__main__": + unittest.main() From 3dc9f4931dec855580998a85c1071ad482112425 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 07:46:40 +0000 Subject: [PATCH 42/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/io/dictionary.py | 5 ++--- tests/test_mapping_filed.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index eb5178b3ba..807371583e 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -17,15 +17,14 @@ from __future__ import annotations -from collections.abc import Hashable, Mapping, Sequence +from collections.abc import Hashable, Mapping from pathlib import Path from typing import Callable import numpy as np -from filelock import FileLock import monai -from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor, PathLike +from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor from monai.data import image_writer from monai.data.image_reader import ImageReader from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping diff --git a/tests/test_mapping_filed.py b/tests/test_mapping_filed.py index a9b4409d7c..835750cb72 100644 --- a/tests/test_mapping_filed.py +++ b/tests/test_mapping_filed.py @@ -97,7 +97,7 @@ def test_mapping_filed(self, save_keys, write_keys): self.assertTrue(os.path.exists(self.mapping_file_path)) - with open(self.mapping_file_path, "r") as f: + with open(self.mapping_file_path) as f: mapping_data = json.load(f) self.assertEqual(len(mapping_data), len(write_keys)) From 3ea0df2682dfd63b3fa23ee858812785e19abb90 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 27 Aug 2024 16:47:33 +0800 Subject: [PATCH 43/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 11 +++++++---- monai/transforms/io/dictionary.py | 1 + 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 93d7c155d9..7836382646 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -537,18 +537,21 @@ class WriteFileMapping(Transform): def __init__(self, mapping_file_path: Path | str = 
"mapping.json"): self.mapping_file_path = Path(mapping_file_path) - def __call__(self, img: MetaTensor): + def __call__(self, img: MetaTensor | torch.Tensor | np.ndarray): """ Args: img (MetaTensor): The input image with metadata. """ - if MetaKeys.SAVED_TO not in img.meta: + if isinstance(img, MetaTensor): + meta_data = img.meta + + if MetaKeys.SAVED_TO not in meta_data: raise KeyError( "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True." ) - input_path = img.meta[Key.FILENAME_OR_OBJ] - output_path = img.meta[MetaKeys.SAVED_TO] + input_path = meta_data[Key.FILENAME_OR_OBJ] + output_path = meta_data[MetaKeys.SAVED_TO] log_data = {"input": input_path, "output": output_path} lock = FileLock(str(self.mapping_file_path) + ".lock") diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 807371583e..be1e78db8a 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -348,3 +348,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N LoadImageD = LoadImageDict = LoadImaged SaveImageD = SaveImageDict = SaveImaged +WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd From 8ad9808256454829af7318d6a5f20e0cb77fdee5 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 27 Aug 2024 17:02:54 +0800 Subject: [PATCH 44/52] fix-issue-7557 Signed-off-by: staydelight --- docs/source/transforms.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 637f0873f1..84f7cb267f 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -553,6 +553,12 @@ IO .. autoclass:: SaveImage :members: :special-members: __call__ + +`WriteFileMapping` +"""""""""""" +.. autoclass:: WriteFileMapping + :members: + :special-members: __call__ NVIDIA Tool Extension (NVTX) @@ -1641,6 +1647,12 @@ IO (Dict) .. autoclass:: SaveImaged :members: :special-members: __call__ + +`WriteFileMappingd` +"""""""""""" +.. autoclass:: WriteFileMappingd + :members: + :special-members: __call__ Post-processing (Dict) ^^^^^^^^^^^^^^^^^^^^^^ From 802e554513ce5b735b892796ce420b2f989ccfe9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 09:04:23 +0000 Subject: [PATCH 45/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/transforms.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 84f7cb267f..1a5b2a738e 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -553,7 +553,7 @@ IO .. autoclass:: SaveImage :members: :special-members: __call__ - + `WriteFileMapping` """""""""""" .. autoclass:: WriteFileMapping @@ -1647,7 +1647,7 @@ IO (Dict) .. autoclass:: SaveImaged :members: :special-members: __call__ - + `WriteFileMappingd` """""""""""" .. 
autoclass:: WriteFileMappingd From 60f5b79acec0a43eb77eeca6a30f880caa581ff3 Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 27 Aug 2024 17:16:40 +0800 Subject: [PATCH 46/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index cf6f35dfe0..f37016e63f 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -247,6 +247,8 @@ SaveImageD, SaveImageDict, WriteFileMappingd, + WriteFileMappingD, + WriteFileMappingDict, ) from .lazy.array import ApplyPending from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict From b7957b6bc713d6ae484f849e949c313974e06eef Mon Sep 17 00:00:00 2001 From: staydelight Date: Tue, 27 Aug 2024 17:49:47 +0800 Subject: [PATCH 47/52] fix-issue-7557 Signed-off-by: staydelight --- docs/source/transforms.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 1a5b2a738e..3e45d899ec 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -555,7 +555,7 @@ IO :special-members: __call__ `WriteFileMapping` -"""""""""""" +"""""""""""""""""" .. autoclass:: WriteFileMapping :members: :special-members: __call__ @@ -1649,7 +1649,7 @@ IO (Dict) :special-members: __call__ `WriteFileMappingd` -"""""""""""" +""""""""""""""""""" .. autoclass:: WriteFileMappingd :members: :special-members: __call__ From 773a218d0655670ab82ef9099c742c135d67d1ca Mon Sep 17 00:00:00 2001 From: staydelight Date: Wed, 28 Aug 2024 00:57:18 +0800 Subject: [PATCH 48/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 7836382646..cde0727dc0 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -537,10 +537,10 @@ class WriteFileMapping(Transform): def __init__(self, mapping_file_path: Path | str = "mapping.json"): self.mapping_file_path = Path(mapping_file_path) - def __call__(self, img: MetaTensor | torch.Tensor | np.ndarray): + def __call__(self, img: NdarrayOrTensor): """ Args: - img (MetaTensor): The input image with metadata. + img: The input image with metadata. 
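+                It should be a MetaTensor whose metadata was populated by ``SaveImage(savepath_in_metadict=True)``; inputs without the ``saved_to`` metadata entry raise a KeyError.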
""" if isinstance(img, MetaTensor): meta_data = img.meta From b28b184ce759cf5b6c425c2e1ef4921bc8c56ba8 Mon Sep 17 00:00:00 2001 From: staydelight Date: Wed, 28 Aug 2024 15:40:48 +0800 Subject: [PATCH 49/52] fix-issue-7557 Signed-off-by: staydelight --- monai/transforms/io/array.py | 31 ++++++++-------- tests/test_mapping_filed.py | 72 +++++++++++++++++++----------------- 2 files changed, 54 insertions(+), 49 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index cde0727dc0..4e71870fc9 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -58,7 +58,7 @@ nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") nrrd, _ = optional_import("nrrd") -FileLock, _ = optional_import("filelock", name="FileLock") +FileLock, has_filelock = optional_import("filelock", name="FileLock") __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"] @@ -554,18 +554,19 @@ def __call__(self, img: NdarrayOrTensor): output_path = meta_data[MetaKeys.SAVED_TO] log_data = {"input": input_path, "output": output_path} - lock = FileLock(str(self.mapping_file_path) + ".lock") - - with lock: - try: - with self.mapping_file_path.open("r") as f: - existing_log_data = json.load(f) - except (FileNotFoundError, json.JSONDecodeError): - existing_log_data = [] - - existing_log_data.append(log_data) - - with self.mapping_file_path.open("w") as f: - json.dump(existing_log_data, f, indent=4) - + if has_filelock: + with FileLock(str(self.mapping_file_path) + ".lock"): + self._write_to_file(log_data) + else: + self._write_to_file(log_data) return img + + def _write_to_file(self, log_data): + try: + with self.mapping_file_path.open("r") as f: + existing_log_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + existing_log_data = [] + existing_log_data.append(log_data) + with self.mapping_file_path.open("w") as f: + json.dump(existing_log_data, f, indent=4) diff --git a/tests/test_mapping_filed.py b/tests/test_mapping_filed.py index 835750cb72..d0f8bcf938 100644 --- a/tests/test_mapping_filed.py +++ b/tests/test_mapping_filed.py @@ -37,11 +37,11 @@ def create_input_file(temp_dir, name): return input_file -TEST_CASE_1 = [["seg"], ["seg"]] -TEST_CASE_2 = [["seg"], ["image"]] -TEST_CASE_3 = [["image"], ["seg"]] -TEST_CASE_4 = [["image", "seg"], ["seg"]] -TEST_CASE_5 = [["seg"], ["image", "seg"]] +# Test cases that should succeed +SUCCESS_CASES = [(["seg"], ["seg"]), (["image", "seg"], ["seg"])] + +# Test cases that should fail +FAILURE_CASES = [(["seg"], ["image"]), (["image"], ["seg"]), (["seg"], ["image", "seg"])] @unittest.skipUnless(has_nib, "nibabel required") @@ -57,9 +57,7 @@ def tearDown(self): if os.path.exists(self.mapping_file_path): os.remove(self.mapping_file_path) - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_mapping_filed(self, save_keys, write_keys): - + def run_test(self, save_keys, write_keys): name = "test_image" input_file = create_input_file(self.temp_dir, name) output_file = os.path.join(self.output_dir, name, name + "_seg.nii.gz") @@ -86,32 +84,38 @@ def test_mapping_filed(self, save_keys, write_keys): model = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32), strides=(2,)).to(device) model.eval() - try: - with torch.no_grad(): - for batch_data in dataloader: - test_inputs = batch_data["image"].to(device) - roi_size = (64, 64, 64) - sw_batch_size = 2 - batch_data["seg"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model) - 
batch_data = [post_transforms(i) for i in decollate_batch(batch_data)] - - self.assertTrue(os.path.exists(self.mapping_file_path)) - - with open(self.mapping_file_path) as f: - mapping_data = json.load(f) - - self.assertEqual(len(mapping_data), len(write_keys)) - for entry in mapping_data: - self.assertEqual(entry["input"], input_file) - self.assertEqual(entry["output"], output_file) - - except RuntimeError as cm: - cause_exception = cm.__cause__ - self.assertIsInstance(cause_exception, KeyError) - self.assertIn( - "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.", - str(cause_exception), - ) + with torch.no_grad(): + for batch_data in dataloader: + test_inputs = batch_data["image"].to(device) + roi_size = (64, 64, 64) + sw_batch_size = 2 + batch_data["seg"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model) + batch_data = [post_transforms(i) for i in decollate_batch(batch_data)] + + return input_file, output_file + + @parameterized.expand(SUCCESS_CASES) + def test_successful_mapping_filed(self, save_keys, write_keys): + input_file, output_file = self.run_test(save_keys, write_keys) + self.assertTrue(os.path.exists(self.mapping_file_path)) + with open(self.mapping_file_path) as f: + mapping_data = json.load(f) + self.assertEqual(len(mapping_data), len(write_keys)) + for entry in mapping_data: + self.assertEqual(entry["input"], input_file) + self.assertEqual(entry["output"], output_file) + + @parameterized.expand(FAILURE_CASES) + def test_failure_mapping_filed(self, save_keys, write_keys): + with self.assertRaises(RuntimeError) as cm: + self.run_test(save_keys, write_keys) + + cause_exception = cm.exception.__cause__ + self.assertIsInstance(cause_exception, KeyError) + self.assertIn( + "Missing 'saved_to' key in metadata. 
Check SaveImage argument 'savepath_in_metadict' is True.", + str(cause_exception), + ) if __name__ == "__main__": From 6f9e44036231e977a2d7d257ceead886d3cfe00d Mon Sep 17 00:00:00 2001 From: staydelight Date: Fri, 11 Oct 2024 15:41:58 +0800 Subject: [PATCH 50/52] fix-issue-6366 Signed-off-by: staydelight --- .github/workflows/cron.yml | 24 +- .github/workflows/pythonapp-gpu.yml | 4 +- .github/workflows/release.yml | 13 +- Dockerfile | 3 +- docs/requirements.txt | 1 + docs/source/config_syntax.md | 42 ++ docs/source/networks.rst | 5 - docs/source/transforms.rst | 36 ++ docs/source/utils.rst | 6 - monai/__init__.py | 45 +- monai/apps/vista3d/inferer.py | 2 +- monai/apps/vista3d/sampler.py | 29 +- monai/bundle/config_parser.py | 8 +- monai/bundle/scripts.py | 8 +- monai/bundle/utils.py | 36 +- monai/bundle/workflows.py | 7 +- monai/config/deviceconfig.py | 10 +- monai/data/image_reader.py | 2 +- monai/engines/evaluator.py | 4 +- monai/engines/trainer.py | 3 +- monai/engines/utils.py | 3 +- monai/engines/workflow.py | 3 +- monai/handlers/__init__.py | 3 +- monai/handlers/checkpoint_loader.py | 3 +- monai/handlers/checkpoint_saver.py | 3 +- monai/handlers/classification_saver.py | 2 +- monai/handlers/decollate_batch.py | 4 +- monai/handlers/earlystop_handler.py | 3 +- monai/handlers/garbage_collector.py | 3 +- monai/handlers/ignite_metric.py | 25 +- monai/handlers/logfile_handler.py | 3 +- monai/handlers/lr_schedule_handler.py | 3 +- monai/handlers/metric_logger.py | 3 +- monai/handlers/metrics_saver.py | 2 +- monai/handlers/mlflow_handler.py | 3 +- monai/handlers/nvtx_handlers.py | 3 +- monai/handlers/parameter_scheduler.py | 3 +- monai/handlers/postprocessing.py | 3 +- monai/handlers/probability_maps.py | 4 +- monai/handlers/smartcache_handler.py | 3 +- monai/handlers/stats_handler.py | 3 +- monai/handlers/tensorboard_handlers.py | 3 +- monai/handlers/trt_handler.py | 60 ++ monai/handlers/utils.py | 4 +- monai/handlers/validation_handler.py | 3 +- monai/losses/dice.py | 17 +- monai/metrics/generalized_dice.py | 125 ++-- monai/networks/__init__.py | 2 + monai/networks/blocks/patchembedding.py | 8 +- monai/networks/layers/filtering.py | 8 +- monai/networks/nets/hovernet.py | 3 +- monai/networks/nets/swin_unetr.py | 8 +- monai/networks/nets/unet.py | 3 - monai/networks/nets/unetr.py | 9 +- monai/networks/nets/vista3d.py | 21 +- monai/networks/nets/vit.py | 8 - monai/networks/nets/vitautoenc.py | 9 +- monai/networks/nets/voxelmorph.py | 5 - monai/networks/trt_compiler.py | 569 ++++++++++++++++++ monai/networks/utils.py | 266 ++++++-- monai/transforms/__init__.py | 12 + monai/transforms/adaptors.py | 5 - monai/transforms/intensity/array.py | 2 +- monai/transforms/spatial/array.py | 44 ++ monai/transforms/spatial/dictionary.py | 61 ++ monai/transforms/spatial/functional.py | 71 ++- monai/transforms/utility/array.py | 157 ++++- monai/transforms/utility/dictionary.py | 108 +++- monai/transforms/utils.py | 41 +- monai/utils/__init__.py | 7 +- monai/utils/aliases.py | 103 ---- monai/utils/dist.py | 2 +- monai/utils/enums.py | 64 +- monai/utils/jupyter_utils.py | 2 +- monai/utils/misc.py | 2 +- monai/utils/module.py | 24 - monai/utils/type_conversion.py | 8 + pyproject.toml | 1 + requirements-dev.txt | 2 + setup.cfg | 4 + tests/min_tests.py | 2 + tests/test_apply_transform_to_points.py | 81 +++ tests/test_apply_transform_to_pointsd.py | 185 ++++++ tests/test_bundle_download.py | 2 +- tests/test_compute_generalized_dice.py | 170 ++++-- tests/test_config_parser.py | 32 + tests/test_convert_box_points.py | 
121 ++++ tests/test_data_stats.py | 41 +- tests/test_data_statsd.py | 54 +- tests/test_fastmri_reader.py | 3 +- tests/test_gdsdataset.py | 7 +- tests/test_handler_garbage_collector.py | 3 +- tests/test_nrrd_reader.py | 8 +- tests/test_rand_weighted_crop.py | 30 + .../test_scale_intensity_range_percentiles.py | 2 + tests/test_sure_loss.py | 2 +- tests/test_trt_compile.py | 140 +++++ 97 files changed, 2568 insertions(+), 539 deletions(-) create mode 100644 monai/handlers/trt_handler.py create mode 100644 monai/networks/trt_compiler.py delete mode 100644 monai/utils/aliases.py create mode 100644 tests/test_apply_transform_to_points.py create mode 100644 tests/test_apply_transform_to_pointsd.py create mode 100644 tests/test_convert_box_points.py create mode 100644 tests/test_trt_compile.py diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index cc113b0446..6732ab7256 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -13,24 +13,24 @@ jobs: strategy: matrix: environment: - - "PT191+CUDA113" - "PT110+CUDA113" - - "PT113+CUDA113" - - "PTLATEST+CUDA121" + - "PT113+CUDA118" + - "PT210+CUDA121" + - "PTLATEST+CUDA124" include: # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes - environment: PT110+CUDA113 pytorch: "torch==1.10.2 torchvision==0.11.3 --extra-index-url https://download.pytorch.org/whl/cu113" base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3 - - environment: PT113+CUDA113 - pytorch: "torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu113" - base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3 - - environment: PT113+CUDA122 + - environment: PT113+CUDA118 pytorch: "torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu121" - base: "nvcr.io/nvidia/pytorch:23.08-py3" # CUDA 12.2 + base: "nvcr.io/nvidia/pytorch:22.10-py3" # CUDA 11.8 + - environment: PT210+CUDA121 + pytorch: "pytorch==2.1.0 torchvision==0.16.0 --extra-index-url https://download.pytorch.org/whl/cu121" + base: "nvcr.io/nvidia/pytorch:23.08-py3" # CUDA 12.1 - environment: PTLATEST+CUDA124 pytorch: "-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121" - base: "nvcr.io/nvidia/pytorch:24.03-py3" # CUDA 12.4 + base: "nvcr.io/nvidia/pytorch:24.08-py3" # CUDA 12.4 container: image: ${{ matrix.base }} options: "--gpus all" @@ -80,7 +80,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:23.08", "pytorch:24.03"] + container: ["pytorch:23.08", "pytorch:24.08"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -129,7 +129,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' strategy: matrix: - container: ["pytorch:24.03"] + container: ["pytorch:24.08"] container: image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image options: "--gpus all" @@ -233,7 +233,7 @@ jobs: if: github.repository == 'Project-MONAI/MONAI' needs: cron-gpu # so that monai itself is verified first container: - image: nvcr.io/nvidia/pytorch:24.03-py3 # testing with the latest pytorch base image + image: nvcr.io/nvidia/pytorch:24.08-py3 # testing with the latest pytorch base image options: "--gpus all --ipc=host" runs-on: [self-hosted, linux, x64, integration] steps: diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index ead622b39c..70c3153076 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ 
b/.github/workflows/pythonapp-gpu.yml @@ -44,9 +44,9 @@ jobs: pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error base: "nvcr.io/nvidia/pytorch:23.08-py3" - environment: PT210+CUDA121DOCKER - # 24.03: 2.3.0a0+40ec155e58.nv24.3 + # 24.08: 2.3.0a0+40ec155e58.nv24.3 pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error - base: "nvcr.io/nvidia/pytorch:24.03-py3" + base: "nvcr.io/nvidia/pytorch:24.08-py3" container: image: ${{ matrix.base }} options: --gpus all --env NVIDIA_DISABLE_REQUIRE=true # workaround for unsatisfied condition: cuda>=11.6 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a014a4ed1d..cb0e109bb7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -78,12 +78,13 @@ jobs: rm dist/monai*.tar.gz ls -al dist/ - - if: matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/') - name: Publish to Test PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - password: ${{ secrets.TEST_PYPI }} - repository-url: https://test.pypi.org/legacy/ + # remove publishing to Test PyPI as it is moved to blossom + # - if: matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/') + # name: Publish to Test PyPI + # uses: pypa/gh-action-pypi-publish@release/v1 + # with: + # password: ${{ secrets.TEST_PYPI }} + # repository-url: https://test.pypi.org/legacy/ versioning: # compute versioning file from python setup.py diff --git a/Dockerfile b/Dockerfile index 8e255597d1..e45932c6bb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ # To build with a different base image # please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag. -ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.03-py3 +ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.08-py3 FROM ${PYTORCH_IMAGE} LABEL maintainer="monai.contact@gmail.com" @@ -56,4 +56,5 @@ RUN apt-get update \ && rm -rf /var/lib/apt/lists/* # append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations ENV PATH=${PATH}:/opt/tools +ENV POLYGRAPHY_AUTOINSTALL_DEPS=1 WORKDIR /opt/monai diff --git a/docs/requirements.txt b/docs/requirements.txt index ff94f7b6de..7307d8e5f9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -42,3 +42,4 @@ zarr huggingface_hub pyamg>=5.0.0 packaging +polygraphy diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md index c932879b5a..742841acca 100644 --- a/docs/source/config_syntax.md +++ b/docs/source/config_syntax.md @@ -16,6 +16,7 @@ Content: - [`$` to evaluate as Python expressions](#to-evaluate-as-python-expressions) - [`%` to textually replace configuration elements](#to-textually-replace-configuration-elements) - [`_target_` (`_disabled_`, `_desc_`, `_requires_`, `_mode_`) to instantiate a Python object](#instantiate-a-python-object) + - [`+` to alter semantics of merging config keys from multiple configuration files](#multiple-config-files) - [The command line interface](#the-command-line-interface) - [Recommendations](#recommendations) @@ -175,6 +176,47 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k - `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``, see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall). +## Multiple config files + +_Description:_ Multiple config files may be specified on the command line. +The content of those config files is being merged. 
When the same key is specified in more than one config file,
+the value associated with the key is overridden, in the order the config files are specified.
+If the desired behaviour is to merge values from both files, the key in the second config file should be prefixed with `+`.
+The value types for the merged contents must match and be both of `dict` or both of `list` type.
+`dict` values will be merged via `update()`, `list` values will be concatenated via `extend()`.
+Here's an example. In this case, the "amp" value will be overridden by extra_config.json, while the
+`imports` and `preprocessing#transforms` lists will be merged. An error would be thrown if the value type in `"+imports"` is not `list`:
+
+config.json:
+```json
+{
+    "amp": "$True",
+    "imports": [
+        "$import torch"
+    ],
+    "preprocessing": {
+        "_target_": "Compose",
+        "transforms": [
+            "$@t1",
+            "$@t2"
+        ]
+    }
+}
+```
+
+extra_config.json:
+```json
+{
+    "amp": "$False",
+    "+imports": [
+        "$from monai.networks import trt_compile"
+    ],
+    "+preprocessing#transforms": [
+        "$@t3"
+    ]
+}
+```
+
 ## The command line interface
 
 In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle.
diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index 1810fec49b..64a3a4c9d1 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -735,14 +735,9 @@ Nets
 .. autoclass:: VoxelMorphUNet
     :members:
 
-.. autoclass:: voxelmorphunet
-    :members:
-
 .. autoclass:: VoxelMorph
     :members:
 
-.. autoclass:: voxelmorph
-
 Utilities
 ---------
 .. automodule:: monai.networks.utils
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 3e45d899ec..41bb4ae79a 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -976,6 +976,18 @@ Spatial
     :members:
     :special-members: __call__
 
+`ConvertBoxToPoints`
+""""""""""""""""""""
+.. autoclass:: ConvertBoxToPoints
+    :members:
+    :special-members: __call__
+
+`ConvertPointsToBoxes`
+""""""""""""""""""""""
+.. autoclass:: ConvertPointsToBoxes
+    :members:
+    :special-members: __call__
+
 Smooth Field
 ^^^^^^^^^^^^
 
@@ -1222,6 +1234,12 @@ Utility
     :members:
     :special-members: __call__
 
+`ApplyTransformToPoints`
+""""""""""""""""""""""""
+.. autoclass:: ApplyTransformToPoints
+    :members:
+    :special-members: __call__
+
 Dictionary Transforms
 ---------------------
 
@@ -1973,6 +1991,18 @@ Spatial (Dict)
     :members:
     :special-members: __call__
 
+`ConvertBoxToPointsd`
+"""""""""""""""""""""
+.. autoclass:: ConvertBoxToPointsd
+    :members:
+    :special-members: __call__
+
+`ConvertPointsToBoxesd`
+"""""""""""""""""""""""
+.. autoclass:: ConvertPointsToBoxesd
+    :members:
+    :special-members: __call__
+
 Smooth Field (Dict)
 ^^^^^^^^^^^^^^^^^^^
 
@@ -2277,6 +2307,12 @@ Utility (Dict)
     :members:
     :special-members: __call__
 
+`ApplyTransformToPointsd`
+"""""""""""""""""""""""""
+.. autoclass:: ApplyTransformToPointsd
+    :members:
+    :special-members: __call__
+
 MetaTensor
 ^^^^^^^^^^
 
diff --git a/docs/source/utils.rst b/docs/source/utils.rst
index fef671e1f8..ae3b476c3e 100644
--- a/docs/source/utils.rst
+++ b/docs/source/utils.rst
@@ -17,12 +17,6 @@ Module utils
     :members:
 
 
-Aliases
--------
-.. automodule:: monai.utils.aliases
-    :members:
-
-
 Misc
 ----
 .. 
automodule:: monai.utils.misc
diff --git a/monai/__init__.py b/monai/__init__.py
index cb0ccd36f8..f6fc8b0646 100644
--- a/monai/__init__.py
+++ b/monai/__init__.py
@@ -13,9 +13,51 @@
 
 import os
 import sys
-
+import logging
+import warnings
 from ._version import get_versions
 
+
+old_showwarning = warnings.showwarning
+
+
+def custom_warning_handler(message, category, filename, lineno, file=None, line=None):
+    ignore_files = ["ignite/handlers/checkpoint", "modelopt/torch/quantization/tensor_quant"]
+    if any(ignore in filename for ignore in ignore_files):
+        return
+    old_showwarning(message, category, filename, lineno, file, line)
+
+
+class DeprecatedTypesWarningFilter(logging.Filter):
+    def filter(self, record):
+        message_bodies_to_ignore = [
+            "np.bool8",
+            "np.object0",
+            "np.int0",
+            "np.uint0",
+            "np.void0",
+            "np.str0",
+            "np.bytes0",
+            "@validator",
+            "@root_validator",
+            "class-based `config`",
+            "pkg_resources",
+            "Implicitly cleaning up",
+        ]
+        for message in message_bodies_to_ignore:
+            if message in record.getMessage():
+                return False
+        return True
+
+
+# workaround for https://github.com/Project-MONAI/MONAI/issues/8060
+# TODO: remove this workaround after upstream fixes the warning
+# Set the custom warning handler to filter warnings
+warnings.showwarning = custom_warning_handler
+# Get the logger for warnings and add the filter to the logger
+logging.getLogger("py.warnings").addFilter(DeprecatedTypesWarningFilter())
+
+
 PY_REQUIRED_MAJOR = 3
 PY_REQUIRED_MINOR = 9
 
@@ -37,6 +79,7 @@
     category=RuntimeWarning,
 )
 
+
 from .utils.module import load_submodules  # noqa: E402
 
 # handlers_* have some external decorators the users may not have installed
diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py
index 709f81f624..8f622ef6cd 100644
--- a/monai/apps/vista3d/inferer.py
+++ b/monai/apps/vista3d/inferer.py
@@ -100,7 +100,7 @@ def point_based_window_inferer(
             point_labels=point_labels,
             class_vector=class_vector,
             prompt_class=prompt_class,
-            patch_coords=unravel_slice,
+            patch_coords=[unravel_slice],
             prev_mask=prev_mask,
             **kwargs,
         )
diff --git a/monai/apps/vista3d/sampler.py b/monai/apps/vista3d/sampler.py
index b7aeb89a2e..17b2d34911 100644
--- a/monai/apps/vista3d/sampler.py
+++ b/monai/apps/vista3d/sampler.py
@@ -20,8 +20,6 @@
 import torch
 from torch import Tensor
 
-__all__ = ["sample_prompt_pairs"]
-
 ENABLE_SPECIAL = True
 SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
 MERGE_LIST = {
@@ -30,6 +28,8 @@
     132: [57],  # overlap with trachea merge into airway
 }
 
+__all__ = ["sample_prompt_pairs"]
+
 
 def _get_point_label(id: int) -> tuple[int, int]:
     if id in SPECIAL_INDEX and ENABLE_SPECIAL:
@@ -66,22 +66,29 @@ def sample_prompt_pairs(
         max_backprompt: int, max number of prompt from background.
         max_point: maximum number of points for each object.
         include_background: whether to include 0 in the training prompt. If included, background 0 is treated
-            the same as foreground. Always be False for multi-partial-dataset training. If needed,
-            can be true for finetuning specific dataset, .
+            the same as foreground and points will be sampled. Can be true only if the user wants to segment
+            background 0 with point clicks; otherwise it should always be false.
         drop_label_prob: probability to drop label prompt.
         drop_point_prob: probability to drop point prompt.
         point_sampler: sampler to augment masks with supervoxel.
         point_sampler_kwargs: arguments for point_sampler.
 
     Returns:
-        label_prompt: [B, 1]. The classes used for training automatic segmentation.
-        point: [B, N, 3]. The corresponding points for each class. 
- Note that background label prompt requires matching point as well ([0,0,0] is used). - point_label: [B, N]. The corresponding point labels for each point (negative or positive). - -1 is used for padding the background label prompt and will be ignored. - prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss. - label_prompt can be None, and prompt_class is used to identify point classes. + tuple: + - label_prompt (Tensor | None): Tensor of shape [B, 1] containing the classes used for + training automatic segmentation. + - point (Tensor | None): Tensor of shape [B, N, 3] representing the corresponding points + for each class. Note that background label prompts require matching points as well + (e.g., [0, 0, 0] is used). + - point_label (Tensor | None): Tensor of shape [B, N] representing the corresponding point + labels for each point (negative or positive). -1 is used for padding the background + label prompt and will be ignored. + - prompt_class (Tensor | None): Tensor of shape [B, 1], exactly the same as label_prompt + for label indexing during training. If label_prompt is None, prompt_class is used to + identify point classes. + """ + # class label number if not labels.shape[0] == 1: raise ValueError("only support batch size 1") diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index a2ffeedc92..1d9920a230 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -20,7 +20,7 @@ from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.bundle.reference_resolver import ReferenceResolver -from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY +from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY, merge_kv from monai.config import PathLike from monai.utils import ensure_tuple, look_up_option, optional_import from monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates @@ -423,8 +423,10 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs if isinstance(files, str) and not Path(files).is_file() and "," in files: files = files.split(",") for i in ensure_tuple(files): - for k, v in (cls.load_config_file(i, **kwargs)).items(): - parser[k] = v + config_dict = cls.load_config_file(i, **kwargs) + for k, v in config_dict.items(): + merge_kv(parser, k, v) + return parser.get() # type: ignore @classmethod diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 142a366669..4251da0b6f 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -32,9 +32,9 @@ from monai.apps.utils import _basename, download_url, extractall, get_logger from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser -from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA +from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow -from monai.config import IgniteInfo, PathLike +from monai.config import PathLike from monai.data import load_net_with_metadata, save_net_with_metadata from monai.networks import ( convert_to_onnx, @@ -45,6 +45,7 @@ save_state, ) from monai.utils import ( + IgniteInfo, check_parent_dir, deprecated_arg, ensure_tuple, @@ -105,7 +106,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw if isinstance(v, dict) and isinstance(args_.get(k), dict): args_[k] = update_kwargs(args_[k], ignore_none, **v) 
else:
-            args_[k] = v
+            merge_kv(args_, k, v)
     return args_
 
 
@@ -255,6 +256,7 @@ def _download_from_ngc_private(
     else:
         raise ValueError("NGC API requires requests package.  Please install it.")
 
+    os.makedirs(download_path, exist_ok=True)
     zip_path = download_path / f"{filename}_v{version}.zip"
     with open(zip_path, "wb") as f:
         f.write(response.content)
diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py
index 50d2608f4c..53d619f234 100644
--- a/monai/bundle/utils.py
+++ b/monai/bundle/utils.py
@@ -13,6 +13,7 @@
 
 import json
 import os
+import warnings
 import zipfile
 from typing import Any
 
@@ -21,12 +22,21 @@
 
 yaml, _ = optional_import("yaml")
 
-__all__ = ["ID_REF_KEY", "ID_SEP_KEY", "EXPR_KEY", "MACRO_KEY", "DEFAULT_MLFLOW_SETTINGS", "DEFAULT_EXP_MGMT_SETTINGS"]
+__all__ = [
+    "ID_REF_KEY",
+    "ID_SEP_KEY",
+    "EXPR_KEY",
+    "MACRO_KEY",
+    "MERGE_KEY",
+    "DEFAULT_MLFLOW_SETTINGS",
+    "DEFAULT_EXP_MGMT_SETTINGS",
+]
 
 ID_REF_KEY = "@"  # start of a reference to a ConfigItem
 ID_SEP_KEY = "::"  # separator for the ID of a ConfigItem
 EXPR_KEY = "$"  # start of a ConfigExpression
 MACRO_KEY = "%"  # start of a macro of a config
+MERGE_KEY = "+"  # prefix indicating merge instead of override in case of multiple configs.
 
 _conf_values = get_config_values()
 
@@ -233,3 +243,27 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
         parser.read_config(f=cdata)
 
     return parser
+
+
+def merge_kv(args: dict | Any, k: str, v: Any) -> None:
+    """
+    Update the `args` dict-like object with the key/value pair `k` and `v`.
+    """
+    if k.startswith(MERGE_KEY):
+        # both values associated with a `+`-prefixed key must be of `dict` or `list` type:
+        # `dict` values are merged via update(), `list` values are concatenated via extend().
+        id = k[1:]
+        if id in args:
+            if isinstance(v, dict) and isinstance(args[id], dict):
+                args[id].update(v)
+            elif isinstance(v, list) and isinstance(args[id], list):
+                args[id].extend(v)
+            else:
+                raise ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}.")
+        else:
+            warnings.warn(f"Can't merge entry ['{k}'], '{id}' is not in target dict - copying instead.")
+            args[id] = v
+    else:
+        args[k] = v
diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py
index 11c9bf0562..d728d7d930 100644
--- a/monai/bundle/workflows.py
+++ b/monai/bundle/workflows.py
@@ -26,7 +26,7 @@
 from monai.bundle.properties import InferProperties, MetaProperties, TrainProperties
 from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY
 from monai.config import PathLike
-from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg, deprecated_arg_default, ensure_tuple
+from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg, ensure_tuple
 
 __all__ = ["BundleWorkflow", "ConfigWorkflow"]
 
@@ -43,7 +43,7 @@ class BundleWorkflow(ABC):
         workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
             or "infer", "inference", "eval", "evaluation" for an inference workflow,
             other unsupported string will raise a ValueError.
-            default to `None` for common workflow.
+            default to `train` for a training workflow.
         workflow: specifies the workflow type: "train" or "training" for a training workflow,
             or "infer", "inference", "eval", "evaluation" for an inference workflow,
             other unsupported string will raise a ValueError.
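For reference, a minimal sketch of the merge semantics implemented by `merge_kv` above (the dict contents here are hypothetical):

```python
from monai.bundle.utils import merge_kv

args = {"imports": ["$import torch"], "amp": "$True"}

# a `+`-prefixed key with matching list types: the values are concatenated via extend()
merge_kv(args, "+imports", ["$from monai.networks import trt_compile"])
assert args["imports"] == ["$import torch", "$from monai.networks import trt_compile"]

# a plain key: the previous value is simply overridden
merge_kv(args, "amp", "$False")
assert args["amp"] == "$False"
```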
@@ -274,7 +274,6 @@ class ConfigWorkflow(BundleWorkflow): new_name="workflow_type", msg_suffix="please use `workflow_type` instead.", ) - @deprecated_arg_default("workflow_type", None, "train", since="1.2", replaced="1.4") def __init__( self, config_file: str | Sequence[str], @@ -284,7 +283,7 @@ def __init__( run_id: str = "run", final_id: str = "finalize", tracking: str | dict | None = None, - workflow_type: str | None = None, + workflow_type: str | None = "train", workflow: str | None = None, properties_path: PathLike | None = None, **override: Any, diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index a4580c741b..05842245ce 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -23,6 +23,8 @@ import torch import monai +from monai.utils.deprecate_utils import deprecated +from monai.utils.enums import IgniteInfo as _IgniteInfo from monai.utils.module import OptionalImportError, get_package_version, optional_import try: @@ -261,13 +263,11 @@ def print_debug_info(file: TextIO = sys.stdout) -> None: print_gpu_info(file) +@deprecated(since="1.4.0", removed="1.6.0", msg_suffix="Please use `monai.utils.enums.IgniteInfo` instead.") class IgniteInfo: - """ - Config information of the PyTorch ignite package. - - """ + """Deprecated Import of IgniteInfo enum, which was moved to `monai.utils.enums.IgniteInfo`.""" - OPT_IMPORT_VERSION = "0.4.4" + OPT_IMPORT_VERSION = _IgniteInfo.OPT_IMPORT_VERSION if __name__ == "__main__": diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index f5e199e2a3..b4ae562911 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -1359,7 +1359,7 @@ def _get_affine(self, header: dict) -> np.ndarray: x, y = direction.shape affine_diam = min(x, y) + 1 affine: np.ndarray = np.eye(affine_diam) - affine[:x, :y] = direction + affine[:x, :y] = direction.T affine[: (affine_diam - 1), -1] = origin # len origin is always affine_diam - 1 return affine diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 2c8dfe6b85..523c3dcbf6 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -17,14 +17,14 @@ import torch from torch.utils.data import DataLoader -from monai.config import IgniteInfo, KeysCollection +from monai.config import KeysCollection from monai.data import MetaTensor from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.networks.utils import eval_mode, train_mode from monai.transforms import Transform -from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import +from monai.utils import ForwardMode, IgniteInfo, ensure_tuple, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys from monai.utils.module import look_up_option, pytorch_after diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index c1364fe015..bbcc9c880b 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -18,13 +18,12 @@ from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader -from monai.config import IgniteInfo from monai.data import MetaTensor from monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import 
Transform -from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, min_version, optional_import +from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, IgniteInfo, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from monai.utils.enums import EngineStatsKeys as ESKeys from monai.utils.module import pytorch_after diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 5339d6965a..11a0000989 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -18,9 +18,8 @@ import torch import torch.nn as nn -from monai.config import IgniteInfo from monai.transforms import apply_transform -from monai.utils import ensure_tuple, min_version, optional_import +from monai.utils import IgniteInfo, ensure_tuple, min_version, optional_import from monai.utils.enums import CommonKeys, GanKeys if TYPE_CHECKING: diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 30622c2b93..3629659db1 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -20,10 +20,9 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from monai.config import IgniteInfo from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.transforms import Decollated -from monai.utils import ensure_tuple, is_scalar, min_version, optional_import +from monai.utils import IgniteInfo, ensure_tuple, is_scalar, min_version, optional_import from .utils import engine_apply_transform diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 641f9aae7d..c1fa448f25 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -20,7 +20,7 @@ from .earlystop_handler import EarlyStopHandler from .garbage_collector import GarbageCollector from .hausdorff_distance import HausdorffDistance -from .ignite_metric import IgniteMetric, IgniteMetricHandler +from .ignite_metric import IgniteMetricHandler from .logfile_handler import LogfileHandler from .lr_schedule_handler import LrScheduleHandler from .mean_dice import MeanDice @@ -40,5 +40,6 @@ from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler +from .trt_handler import TrtHandler from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports from .validation_handler import ValidationHandler diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 9a867534a3..f48968ecfd 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -17,9 +17,8 @@ import torch -from monai.config import IgniteInfo from monai.networks.utils import copy_model_state -from monai.utils import min_version, optional_import +from monai.utils import IgniteInfo, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index 0651c6ff33..2a3a467570 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -17,8 +17,7 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from monai.config import IgniteInfo -from monai.utils import 
is_scalar, min_version, optional_import +from monai.utils import IgniteInfo, is_scalar, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 831808f4fb..ffcfe3c1fb 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -18,8 +18,8 @@ import torch -from monai.config import IgniteInfo from monai.data import CSVSaver, decollate_batch +from monai.utils import IgniteInfo from monai.utils import ImageMetaKey as Key from monai.utils import evenly_divisible_all_gather, min_version, optional_import, string_list_all_gather diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py index ac3aa94145..81415bd56e 100644 --- a/monai/handlers/decollate_batch.py +++ b/monai/handlers/decollate_batch.py @@ -13,10 +13,10 @@ from typing import TYPE_CHECKING -from monai.config import IgniteInfo, KeysCollection +from monai.config import KeysCollection from monai.engines.utils import IterationEvents from monai.transforms import Decollated -from monai.utils import min_version, optional_import +from monai.utils import IgniteInfo, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") if TYPE_CHECKING: diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py index 93334bf5c0..0562335192 100644 --- a/monai/handlers/earlystop_handler.py +++ b/monai/handlers/earlystop_handler.py @@ -14,8 +14,7 @@ from collections.abc import Callable from typing import TYPE_CHECKING -from monai.config import IgniteInfo -from monai.utils import min_version, optional_import +from monai.utils import IgniteInfo, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") EarlyStopping, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EarlyStopping") diff --git a/monai/handlers/garbage_collector.py b/monai/handlers/garbage_collector.py index 3d7e948364..586fa10d33 100644 --- a/monai/handlers/garbage_collector.py +++ b/monai/handlers/garbage_collector.py @@ -14,8 +14,7 @@ import gc from typing import TYPE_CHECKING -from monai.config import IgniteInfo -from monai.utils import min_version, optional_import +from monai.utils import IgniteInfo, min_version, optional_import if TYPE_CHECKING: from ignite.engine import Engine, Events diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index 021154d705..44a5634c42 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -18,9 +18,8 @@ import torch from torch.nn.modules.loss import _Loss -from monai.config import IgniteInfo from monai.metrics import CumulativeIterationMetric, LossMetric -from monai.utils import MetricReduction, deprecated, min_version, optional_import +from monai.utils import IgniteInfo, MetricReduction, min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") @@ -153,25 +152,3 @@ def attach(self, engine: Engine, name: str) -> None: # type: ignore[override] self._name = name if self.save_details and not hasattr(engine.state, "metric_details"): engine.state.metric_details = {} # type: ignore - - -@deprecated(since="1.2", removed="1.4", msg_suffix="Use IgniteMetricHandler instead of IgniteMetric.") -class 
IgniteMetric(IgniteMetricHandler): - - def __init__( - self, - metric_fn: CumulativeIterationMetric | None = None, - loss_fn: _Loss | None = None, - output_transform: Callable = lambda x: x, - save_details: bool = True, - reduction: MetricReduction | str = MetricReduction.MEAN, - get_not_nans: bool = False, - ) -> None: - super().__init__( - metric_fn=metric_fn, - loss_fn=loss_fn, - output_transform=output_transform, - save_details=save_details, - reduction=reduction, - get_not_nans=get_not_nans, - ) diff --git a/monai/handlers/logfile_handler.py b/monai/handlers/logfile_handler.py index df6ebd34a7..0c44ae47f4 100644 --- a/monai/handlers/logfile_handler.py +++ b/monai/handlers/logfile_handler.py @@ -15,8 +15,7 @@ import os from typing import TYPE_CHECKING -from monai.config import IgniteInfo -from monai.utils import min_version, optional_import +from monai.utils import IgniteInfo, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") if TYPE_CHECKING: diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py index a79722517d..8d90992a84 100644 --- a/monai/handlers/lr_schedule_handler.py +++ b/monai/handlers/lr_schedule_handler.py @@ -17,8 +17,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler -from monai.config import IgniteInfo -from monai.utils import ensure_tuple, min_version, optional_import +from monai.utils import IgniteInfo, ensure_tuple, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") if TYPE_CHECKING: diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index d59205a021..62cdee6509 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -17,8 +17,7 @@ from threading import RLock from typing import TYPE_CHECKING, Any -from monai.config import IgniteInfo -from monai.utils import min_version, optional_import +from monai.utils import IgniteInfo, min_version, optional_import from monai.utils.enums import CommonKeys Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index 88a0926b91..6175b1242a 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -14,9 +14,9 @@ from collections.abc import Callable, Sequence from typing import TYPE_CHECKING -from monai.config import IgniteInfo from monai.data import decollate_batch from monai.handlers.utils import write_metrics_reports +from monai.utils import IgniteInfo from monai.utils import ImageMetaKey as Key from monai.utils import ensure_tuple, min_version, optional_import, string_list_all_gather diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index 6d19579d9e..c7e293ea7d 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -22,8 +22,7 @@ from torch.utils.data import Dataset from monai.apps.utils import get_logger -from monai.config import IgniteInfo -from monai.utils import CommonKeys, ensure_tuple, min_version, optional_import +from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.") diff --git a/monai/handlers/nvtx_handlers.py 
b/monai/handlers/nvtx_handlers.py index 38eef6f05b..bd22af0db8 100644 --- a/monai/handlers/nvtx_handlers.py +++ b/monai/handlers/nvtx_handlers.py @@ -16,8 +16,7 @@ from typing import TYPE_CHECKING -from monai.config import IgniteInfo -from monai.utils import ensure_tuple, min_version, optional_import +from monai.utils import IgniteInfo, ensure_tuple, min_version, optional_import _nvtx, _ = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?") if TYPE_CHECKING: diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py index d12e6e072c..1ce6193b6d 100644 --- a/monai/handlers/parameter_scheduler.py +++ b/monai/handlers/parameter_scheduler.py @@ -16,8 +16,7 @@ from collections.abc import Callable from typing import TYPE_CHECKING -from monai.config import IgniteInfo -from monai.utils import min_version, optional_import +from monai.utils import IgniteInfo, min_version, optional_import if TYPE_CHECKING: from ignite.engine import Engine, Events diff --git a/monai/handlers/postprocessing.py b/monai/handlers/postprocessing.py index c698c84338..541b5924d1 100644 --- a/monai/handlers/postprocessing.py +++ b/monai/handlers/postprocessing.py @@ -14,9 +14,8 @@ from collections.abc import Callable from typing import TYPE_CHECKING -from monai.config import IgniteInfo from monai.engines.utils import IterationEvents, engine_apply_transform -from monai.utils import min_version, optional_import +from monai.utils import IgniteInfo, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") if TYPE_CHECKING: diff --git a/monai/handlers/probability_maps.py b/monai/handlers/probability_maps.py index 8a60fcc983..e21bd199f8 100644 --- a/monai/handlers/probability_maps.py +++ b/monai/handlers/probability_maps.py @@ -17,10 +17,10 @@ import numpy as np -from monai.config import DtypeLike, IgniteInfo +from monai.config import DtypeLike from monai.data.folder_layout import FolderLayout from monai.utils import ProbMapKeys, min_version, optional_import -from monai.utils.enums import CommonKeys +from monai.utils.enums import CommonKeys, IgniteInfo Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") if TYPE_CHECKING: diff --git a/monai/handlers/smartcache_handler.py b/monai/handlers/smartcache_handler.py index ee043635db..e07e98e541 100644 --- a/monai/handlers/smartcache_handler.py +++ b/monai/handlers/smartcache_handler.py @@ -13,9 +13,8 @@ from typing import TYPE_CHECKING -from monai.config import IgniteInfo from monai.data import SmartCacheDataset -from monai.utils import min_version, optional_import +from monai.utils import IgniteInfo, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") if TYPE_CHECKING: diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index c49fcda819..ab36d19bd1 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -19,8 +19,7 @@ import torch from monai.apps import get_logger -from monai.config import IgniteInfo -from monai.utils import is_scalar, min_version, optional_import +from monai.utils import IgniteInfo, is_scalar, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") if TYPE_CHECKING: diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index 
7b7e3968fb..44a03710de 100644
--- a/monai/handlers/tensorboard_handlers.py
+++ b/monai/handlers/tensorboard_handlers.py
@@ -18,8 +18,7 @@
 import numpy as np
 import torch
 
-from monai.config import IgniteInfo
-from monai.utils import is_scalar, min_version, optional_import
+from monai.utils import IgniteInfo, is_scalar, min_version, optional_import
 from monai.visualize import plot_2d_or_3d_image
 
 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
diff --git a/monai/handlers/trt_handler.py b/monai/handlers/trt_handler.py
new file mode 100644
index 0000000000..45e2669f70
--- /dev/null
+++ b/monai/handlers/trt_handler.py
@@ -0,0 +1,60 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from monai.networks import trt_compile
+from monai.utils import IgniteInfo, min_version, optional_import
+
+Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
+if TYPE_CHECKING:
+    from ignite.engine import Engine
+else:
+    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+
+
+class TrtHandler:
+    """
+    TrtHandler acts as an Ignite handler to apply TRT acceleration to the model.
+    Usage example::
+        handler = TrtHandler(model=model, base_path="/test/checkpoint.pt", args={"precision": "fp16"})
+        handler.attach(engine)
+        engine.run()
+    """
+
+    def __init__(self, model, base_path, args=None, submodule=None):
+        """
+        Args:
+            model: the model to apply TRT acceleration to (or whose submodules are converted, when `submodule` is given).
+            base_path: TRT path basename. TRT plan(s) saved to "base_path[.submodule].plan"
+            args: passed to trt_compile(). See trt_compile() for details.
+            submodule: hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder'.
+        """
+        self.model = model
+        self.base_path = base_path
+        self.args = args
+        self.submodule = submodule
+
+    def attach(self, engine: Engine) -> None:
+        """
+        Args:
+            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+        """
+        self.logger = engine.logger
+        engine.add_event_handler(Events.STARTED, self)
+
+    def __call__(self, engine: Engine) -> None:
+        """
+        Args:
+            engine: Ignite Engine, it can be a trainer, validator or evaluator.
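+
+        Note: the handler is attached to ``Events.STARTED`` (see ``attach``), so the model is wrapped for TRT before the first iteration runs.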
+ """ + trt_compile(self.model, self.base_path, args=self.args, submodule=self.submodule, logger=self.logger) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 0cd31b89c2..b6771f2dcc 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -19,8 +19,8 @@ import numpy as np import torch -from monai.config import IgniteInfo, KeysCollection, PathLike -from monai.utils import ensure_tuple, look_up_option, min_version, optional_import +from monai.config import KeysCollection, PathLike +from monai.utils import IgniteInfo, ensure_tuple, look_up_option, min_version, optional_import idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") if TYPE_CHECKING: diff --git a/monai/handlers/validation_handler.py b/monai/handlers/validation_handler.py index 89c7715f42..38dd511aa4 100644 --- a/monai/handlers/validation_handler.py +++ b/monai/handlers/validation_handler.py @@ -13,9 +13,8 @@ from typing import TYPE_CHECKING -from monai.config import IgniteInfo from monai.engines.evaluator import Evaluator -from monai.utils import min_version, optional_import +from monai.utils import IgniteInfo, min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") if TYPE_CHECKING: diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 44cde41e5d..3f02fae6b8 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -24,7 +24,7 @@ from monai.losses.focal_loss import FocalLoss from monai.losses.spatial_mask import MaskedLoss from monai.networks import one_hot -from monai.utils import DiceCEReduction, LossReduction, Weight, deprecated_arg, look_up_option, pytorch_after +from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after class DiceLoss(_Loss): @@ -646,9 +646,6 @@ class DiceCELoss(_Loss): """ - @deprecated_arg( - "ce_weight", since="1.2", removed="1.4", new_name="weight", msg_suffix="please use `weight` instead." - ) def __init__( self, include_background: bool = True, @@ -662,7 +659,6 @@ def __init__( smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, - ce_weight: torch.Tensor | None = None, weight: torch.Tensor | None = None, lambda_dice: float = 1.0, lambda_ce: float = 1.0, @@ -712,7 +708,6 @@ def __init__( """ super().__init__() reduction = look_up_option(reduction, DiceCEReduction).value - weight = ce_weight if ce_weight is not None else weight dice_weight: torch.Tensor | None if weight is not None and not include_background: dice_weight = weight[1:] @@ -825,9 +820,6 @@ class DiceFocalLoss(_Loss): """ - @deprecated_arg( - "focal_weight", since="1.2", removed="1.4", new_name="weight", msg_suffix="please use `weight` instead." - ) def __init__( self, include_background: bool = True, @@ -842,7 +834,6 @@ def __init__( smooth_dr: float = 1e-5, batch: bool = False, gamma: float = 2.0, - focal_weight: Sequence[float] | float | int | torch.Tensor | None = None, weight: Sequence[float] | float | int | torch.Tensor | None = None, lambda_dice: float = 1.0, lambda_focal: float = 1.0, @@ -885,7 +876,6 @@ def __init__( [0, 1]. Defaults to None. """ super().__init__() - weight = focal_weight if focal_weight is not None else weight self.dice = DiceLoss( include_background=include_background, to_onehot_y=False, @@ -994,9 +984,6 @@ class GeneralizedDiceFocalLoss(_Loss): ValueError: if either `lambda_gdl` or `lambda_focal` is less than 0. 
""" - @deprecated_arg( - "focal_weight", since="1.2", removed="1.4", new_name="weight", msg_suffix="please use `weight` instead." - ) def __init__( self, include_background: bool = True, @@ -1010,7 +997,6 @@ def __init__( smooth_dr: float = 1e-5, batch: bool = False, gamma: float = 2.0, - focal_weight: Sequence[float] | float | int | torch.Tensor | None = None, weight: Sequence[float] | float | int | torch.Tensor | None = None, lambda_gdl: float = 1.0, lambda_focal: float = 1.0, @@ -1028,7 +1014,6 @@ def __init__( smooth_dr=smooth_dr, batch=batch, ) - weight = focal_weight if focal_weight is not None else weight self.focal = FocalLoss( include_background=include_background, to_onehot_y=to_onehot_y, diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index e56bd46592..516021949b 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -14,34 +14,47 @@ import torch from monai.metrics.utils import do_metric_reduction, ignore_background -from monai.utils import MetricReduction, Weight, look_up_option +from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option from .metric import CumulativeIterationMetric class GeneralizedDiceScore(CumulativeIterationMetric): - """Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in: + """ + Compute the Generalized Dice Score metric between tensors. + This metric is the complement of the Generalized Dice Loss defined in: Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning - loss function for highly unbalanced segmentations. DLMIA 2017. + loss function for highly unbalanced segmentations. DLMIA 2017. - The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first - or batch-first tensors, i.e., CHW[D] or BCHW[D]. + The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D]. Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. Args: - include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the + include_background: Whether to include the background class (assumed to be in channel 0) in the score computation. Defaults to True. - reduction (str, optional): define mode of reduction to the metrics. Available reduction modes: - {``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction. - weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform + reduction: Define mode of reduction to the metrics. Available reduction modes: + {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. Raises: - ValueError: when the `weight_type` is not one of {``"none"``, ``"mean"``, ``"sum"``}. + ValueError: When the `reduction` is not one of MetricReduction enum. 
""" + @deprecated_arg_default( + "reduction", + old_default=MetricReduction.MEAN_BATCH, + new_default=MetricReduction.MEAN, + since="1.4.0", + replaced="1.5.0", + msg_suffix=( + "Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, " + "If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'." + ), + ) def __init__( self, include_background: bool = True, @@ -50,79 +63,90 @@ def __init__( ) -> None: super().__init__() self.include_background = include_background - reduction_options = [ - "none", - "mean_batch", - "sum_batch", - MetricReduction.NONE, - MetricReduction.MEAN_BATCH, - MetricReduction.SUM_BATCH, - ] - self.reduction = reduction - if self.reduction not in reduction_options: - raise ValueError(f"reduction must be one of {reduction_options}") + self.reduction = look_up_option(reduction, MetricReduction) self.weight_type = look_up_option(weight_type, Weight) + self.sum_over_classes = self.reduction in { + MetricReduction.SUM, + MetricReduction.MEAN, + MetricReduction.MEAN_CHANNEL, + MetricReduction.SUM_CHANNEL, + } def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] - """Computes the Generalized Dice Score and returns a tensor with its per image values. + """ + Computes the Generalized Dice Score and returns a tensor with its per image values. Args: - y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, + y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. - y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. + y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. + + Returns: + torch.Tensor: Generalized Dice Score averaged across batch and class Raises: - ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape. + ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape. """ return compute_generalized_dice( - y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type + y_pred=y_pred, + y=y, + include_background=self.include_background, + weight_type=self.weight_type, + sum_over_classes=self.sum_over_classes, ) + @deprecated_arg( + "reduction", + since="1.3.3", + removed="1.7.0", + msg_suffix="Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute", + ) def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor: """ Execute reduction logic for the output of `compute_generalized_dice`. - Args: - reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics. - Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}. - Defaults to ``"mean"``. If "none", will not do reduction. + Returns: + torch.Tensor: Aggregated metric value. + + Raises: + ValueError: If the data to aggregate is not a PyTorch Tensor. 
""" data = self.get_buffer() if not isinstance(data, torch.Tensor): raise ValueError("The data to aggregate must be a PyTorch Tensor.") - # Validate reduction argument if specified - if reduction is not None: - reduction_options = ["none", "mean", "sum", "mean_batch", "sum_batch"] - if reduction not in reduction_options: - raise ValueError(f"reduction must be one of {reduction_options}") - # Do metric reduction and return - f, _ = do_metric_reduction(data, reduction or self.reduction) + f, _ = do_metric_reduction(data, self.reduction) return f def compute_generalized_dice( - y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE + y_pred: torch.Tensor, + y: torch.Tensor, + include_background: bool = True, + weight_type: Weight | str = Weight.SQUARE, + sum_over_classes: bool = False, ) -> torch.Tensor: - """Computes the Generalized Dice Score and returns a tensor with its per image values. + """ + Computes the Generalized Dice Score and returns a tensor with its per image values. Args: - y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format + y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. - y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`. - include_background (bool, optional): whether to include score computation on the first channel of the + y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`. + include_background: Whether to include score computation on the first channel of the predicted output. Defaults to True. weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. + sum_over_labels (bool): Whether to sum the numerator and denominator across all labels before the final computation. Returns: - torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes]. + torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes]. Raises: - ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions, + ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions, or `y_pred` and `y` don't have the same shape. """ # Ensure tensors have at least 3 dimensions and have the same shape @@ -158,16 +182,21 @@ def compute_generalized_dice( b[infs] = 0 b[infs] = torch.max(b) - # Compute the weighted numerator and denominator, summing along the class axis - numer = 2.0 * (intersection * w).sum(dim=1) - denom = (denominator * w).sum(dim=1) + # Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True + if sum_over_classes: + numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True) + denom = (denominator * w).sum(dim=1, keepdim=True) + y_pred_o = y_pred_o.sum(dim=-1, keepdim=True) + else: + numer = 2.0 * (intersection * w) + denom = denominator * w + y_pred_o = y_pred_o # Compute the score generalized_dice_score = numer / denom # Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1. 
# Where denom == 0 but the prediction volume is not 0, score is 0 - y_pred_o = y_pred_o.sum(dim=-1) denom_zeros = denom == 0 generalized_dice_score[denom_zeros] = torch.where( (y_pred_o == 0)[denom_zeros], diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 4c429ae813..5a240021d6 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -11,7 +11,9 @@ from __future__ import annotations +from .trt_compiler import trt_compile from .utils import ( + add_casts_around_norms, convert_to_onnx, convert_to_torchscript, convert_to_trt, diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 91bd73ebbb..fca566591a 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -21,7 +21,7 @@ from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding from monai.networks.layers import Conv, trunc_normal_ -from monai.utils import deprecated_arg, ensure_tuple_rep, optional_import +from monai.utils import ensure_tuple_rep, optional_import from monai.utils.module import look_up_option Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -42,9 +42,6 @@ class PatchEmbeddingBlock(nn.Module): """ - @deprecated_arg( - name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead." - ) def __init__( self, in_channels: int, @@ -52,7 +49,6 @@ def __init__( patch_size: Sequence[int] | int, hidden_size: int, num_heads: int, - pos_embed: str = "conv", proj_type: str = "conv", pos_embed_type: str = "learnable", dropout_rate: float = 0.0, @@ -69,8 +65,6 @@ def __init__( pos_embed_type: position embedding layer type. dropout_rate: fraction of the input units to drop. spatial_dims: number of spatial dimensions. - .. deprecated:: 1.4 - ``pos_embed`` is deprecated in favor of ``proj_type``. 
""" super().__init__() diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 0ff1187dcc..c48c77cf98 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -51,6 +51,8 @@ def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True): ctx.cs = color_sigma ctx.fa = fast_approx output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx) + if torch.cuda.is_available(): + torch.cuda.synchronize() return output_data @staticmethod @@ -139,7 +141,8 @@ def forward(ctx, input_img, sigma_x, sigma_y, sigma_z, color_sigma): do_dsig_y, do_dsig_z, ) - + if torch.cuda.is_available(): + torch.cuda.synchronize() return output_tensor @staticmethod @@ -301,7 +304,8 @@ def forward(ctx, input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma do_dsig_z, guidance_img, ) - + if torch.cuda.is_available(): + torch.cuda.synchronize() return output_tensor @staticmethod diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py index 5f340c9be6..3745b66bb5 100644 --- a/monai/networks/nets/hovernet.py +++ b/monai/networks/nets/hovernet.py @@ -43,7 +43,7 @@ from monai.networks.layers.factories import Conv, Dropout from monai.networks.layers.utils import get_act_layer, get_norm_layer from monai.utils.enums import HoVerNetBranch, HoVerNetMode, InterpolateMode, UpsampleMode -from monai.utils.module import export, look_up_option +from monai.utils.module import look_up_option __all__ = ["HoVerNet", "Hovernet", "HoVernet", "HoVerNet"] @@ -409,7 +409,6 @@ def forward(self, xin: torch.Tensor, short_cuts: list[torch.Tensor]) -> torch.Te return x -@export("monai.networks.nets") class HoVerNet(nn.Module): """HoVerNet model diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 3900c866b3..714d986f4b 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -320,7 +320,7 @@ def _check_input_size(self, spatial_shape): ) def forward(self, x_in): - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): self._check_input_size(x_in.shape[2:]) hidden_states_out = self.swinViT(x_in, self.normalize) enc0 = self.encoder1(x_in) @@ -1046,14 +1046,14 @@ def __init__( def proj_out(self, x, normalize=False): if normalize: - x_shape = x.size() + x_shape = x.shape + # Force trace() to generate a constant by casting to int + ch = int(x_shape[1]) if len(x_shape) == 5: - n, ch, d, h, w = x_shape x = rearrange(x, "n c d h w -> n d h w c") x = F.layer_norm(x, [ch]) x = rearrange(x, "n d h w c -> n c d h w") elif len(x_shape) == 4: - n, ch, h, w = x_shape x = rearrange(x, "n c h w -> n h w c") x = F.layer_norm(x, [ch]) x = rearrange(x, "n h w c -> n c h w") diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 7b16b6c923..eac0ddab39 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -20,13 +20,10 @@ from monai.networks.blocks.convolutions import Convolution, ResidualUnit from monai.networks.layers.factories import Act, Norm from monai.networks.layers.simplelayers import SkipConnection -from monai.utils import alias, export __all__ = ["UNet", "Unet"] -@export("monai.networks.nets") -@alias("Unet") class UNet(nn.Module): """ Enhanced version of UNet which has residual units implemented with the ResidualUnit class. 
diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py
index a88e5a92fd..79ea0e23f7 100644
--- a/monai/networks/nets/unetr.py
+++ b/monai/networks/nets/unetr.py
@@ -18,7 +18,7 @@
 from monai.networks.blocks.dynunet_block import UnetOutBlock
 from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
 from monai.networks.nets.vit import ViT
-from monai.utils import deprecated_arg, ensure_tuple_rep
+from monai.utils import ensure_tuple_rep


 class UNETR(nn.Module):
@@ -27,9 +27,6 @@ class UNETR(nn.Module):
     UNETR: Transformers for 3D Medical Image Segmentation "
     """

-    @deprecated_arg(
-        name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead."
-    )
     def __init__(
         self,
         in_channels: int,
@@ -39,7 +36,6 @@ def __init__(
         hidden_size: int = 768,
         mlp_dim: int = 3072,
         num_heads: int = 12,
-        pos_embed: str = "conv",
         proj_type: str = "conv",
         norm_name: tuple | str = "instance",
         conv_block: bool = True,
@@ -67,9 +63,6 @@ def __init__(
             qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False.
             save_attn: to make accessible the attention in self attention block. Defaults to False.

-        .. deprecated:: 1.4
-            ``pos_embed`` is deprecated in favor of ``proj_type``.
-
         Examples::

             # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py
index 9148e36542..4215a9a594 100644
--- a/monai/networks/nets/vista3d.py
+++ b/monai/networks/nets/vista3d.py
@@ -336,11 +336,11 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
     def forward(
         self,
         input_images: torch.Tensor,
+        patch_coords: list[Sequence[slice]] | None = None,
         point_coords: torch.Tensor | None = None,
         point_labels: torch.Tensor | None = None,
         class_vector: torch.Tensor | None = None,
         prompt_class: torch.Tensor | None = None,
-        patch_coords: Sequence[slice] | None = None,
         labels: torch.Tensor | None = None,
         label_set: Sequence[int] | None = None,
         prev_mask: torch.Tensor | None = None,
@@ -364,8 +364,12 @@ def forward(
                 the points are for zero-shot or supported class. When class_vector and point_coords are both provided,
                 prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] will be considered novel class.
-            patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
-                This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase.
+            patch_coords: a list of sequences of the python slice objects representing the patch coordinates during sliding window
+                inference. This value is passed from sliding_window_inferer.
+                This is an indicator for training phase or validation phase.
+                Notice that for a sliding-window batch size > 1 (supported only by automatic segmentation), patch_coords will
+                include the coordinates of multiple patches. If point prompts are included, the batch size can only be one, and
+                all the functions using patch_coords will by default use patch_coords[0].
             labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
             label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
                 this label_set should be global mapped index. If labels are not mapped to global index, e.g.
                 in zero-shot
@@ -395,14 +399,14 @@ def forward(
             if val_point_sampler is None:
                 # TODO: think about how to refactor this part.
                 val_point_sampler = self.sample_points_patch_val
-            point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set)
+            point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords[0], label_set)
             if prompt_class[0].item() == 0:  # type: ignore
                 point_labels[0] = -1  # type: ignore
             labels, prev_mask = None, None
         elif point_coords is not None:
             # If not performing patch-based point only validation, use user provided click points for inference.
             # the point clicks are in the original image space; convert them to the current patch-coordinate space.
-            point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels)  # type: ignore
+            point_coords, point_labels = self.update_point_to_patch(patch_coords[0], point_coords, point_labels)  # type: ignore

         if point_coords is not None and point_labels is not None:
             # remove points that were used for padding purposes (point_label = -1)
@@ -421,7 +425,10 @@ def forward(
             point_coords, point_labels = None, None

         if point_coords is None and class_vector is None:
-            return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
+            logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
+            if transpose:
+                logits = logits.transpose(1, 0)
+            return logits

         if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
             out, out_auto = self.image_embeddings, None
@@ -452,7 +459,7 @@ def forward(
             logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
             if prev_mask is not None and patch_coords is not None:
                 logits = self.connected_components_combine(
-                    prev_mask[patch_coords].transpose(1, 0).to(logits.device),
+                    prev_mask[patch_coords[0]].transpose(1, 0).to(logits.device),
                     logits[mapping_index],
                     point_coords,  # type: ignore
                     point_labels,  # type: ignore
diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py
index 4eada6aa76..07c5147cb2 100644
--- a/monai/networks/nets/vit.py
+++ b/monai/networks/nets/vit.py
@@ -18,7 +18,6 @@
 from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
 from monai.networks.blocks.transformerblock import TransformerBlock
-from monai.utils import deprecated_arg

 __all__ = ["ViT"]

@@ -31,9 +30,6 @@ class ViT(nn.Module):
     ViT supports Torchscript but only works for Pytorch after 1.8.
     """

-    @deprecated_arg(
-        name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead."
-    )
     def __init__(
         self,
         in_channels: int,
@@ -43,7 +39,6 @@ def __init__(
         mlp_dim: int = 3072,
         num_layers: int = 12,
         num_heads: int = 12,
-        pos_embed: str = "conv",
         proj_type: str = "conv",
         pos_embed_type: str = "learnable",
         classification: bool = False,
@@ -75,9 +70,6 @@ def __init__(
             qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
             save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.

-        .. deprecated:: 1.4
-            ``pos_embed`` is deprecated in favor of ``proj_type``.
-
         Examples::

             # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py
index d69f5df4be..3c20f9a784 100644
--- a/monai/networks/nets/vitautoenc.py
+++ b/monai/networks/nets/vitautoenc.py
@@ -20,7 +20,7 @@
 from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
 from monai.networks.blocks.transformerblock import TransformerBlock
 from monai.networks.layers import Conv
-from monai.utils import deprecated_arg, ensure_tuple_rep, is_sqrt
+from monai.utils import ensure_tuple_rep, is_sqrt

 __all__ = ["ViTAutoEnc"]

@@ -33,9 +33,6 @@ class ViTAutoEnc(nn.Module):
     Modified to also give same dimension outputs as the input size of the image
     """

-    @deprecated_arg(
-        name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead."
-    )
     def __init__(
         self,
         in_channels: int,
@@ -47,7 +44,6 @@ def __init__(
         mlp_dim: int = 3072,
         num_layers: int = 12,
         num_heads: int = 12,
-        pos_embed: str = "conv",
         proj_type: str = "conv",
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
@@ -71,9 +67,6 @@ def __init__(
             qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
             save_attn: to make accessible the attention in self attention block. Defaults to False.

-        .. deprecated:: 1.4
-            ``pos_embed`` is deprecated in favor of ``proj_type``.
-
         Examples::

             # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
diff --git a/monai/networks/nets/voxelmorph.py b/monai/networks/nets/voxelmorph.py
index 0496cfc8f8..4923b6ad60 100644
--- a/monai/networks/nets/voxelmorph.py
+++ b/monai/networks/nets/voxelmorph.py
@@ -21,13 +21,10 @@
 from monai.networks.blocks.upsample import UpSample
 from monai.networks.blocks.warp import DVF2DDF, Warp
 from monai.networks.layers.simplelayers import SkipConnection
-from monai.utils import alias, export

 __all__ = ["VoxelMorphUNet", "voxelmorphunet", "VoxelMorph", "voxelmorph"]


-@export("monai.networks.nets")
-@alias("voxelmorphunet")
 class VoxelMorphUNet(nn.Module):
     """
     The backbone network used in VoxelMorph. See :py:class:`monai.networks.nets.VoxelMorph` for more details.
@@ -340,8 +337,6 @@ def forward(self, concatenated_pairs: torch.Tensor) -> torch.Tensor:
 voxelmorphunet = VoxelMorphUNet


-@export("monai.networks.nets")
-@alias("voxelmorph")
 class VoxelMorph(nn.Module):
     """
     A re-implementation of VoxelMorph framework for medical image registration as described in
diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py
new file mode 100644
index 0000000000..00d2eb61af
--- /dev/null
+++ b/monai/networks/trt_compiler.py
@@ -0,0 +1,569 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from __future__ import annotations + +import inspect +import os +import tempfile +import threading +from collections import OrderedDict +from pathlib import Path +from types import MethodType +from typing import Any, Dict, List, Union + +import torch + +from monai.apps.utils import get_logger +from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript, get_profile_shapes +from monai.utils.module import optional_import + +polygraphy, polygraphy_imported = optional_import("polygraphy") +if polygraphy_imported: + from polygraphy.backend.common import bytes_from_path + from polygraphy.backend.trt import ( + CreateConfig, + Profile, + engine_bytes_from_network, + engine_from_bytes, + network_from_onnx_path, + ) + +trt, trt_imported = optional_import("tensorrt") +torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0") +cudart, _ = optional_import("cuda.cudart") + + +lock_sm = threading.Lock() + + +# Map of TRT dtype -> Torch dtype +def trt_to_torch_dtype_dict(): + return { + trt.int32: torch.int32, + trt.float32: torch.float32, + trt.float16: torch.float16, + trt.bfloat16: torch.float16, + trt.int64: torch.int64, + trt.int8: torch.int8, + trt.bool: torch.bool, + } + + +def get_dynamic_axes(profiles): + """ + This method calculates dynamic_axes to use in onnx.export(). + Args: + profiles: [[min,opt,max],...] list of profile dimensions + """ + dynamic_axes: dict[str, list[int]] = {} + if not profiles: + return dynamic_axes + for profile in profiles: + for key in profile: + axes = [] + vals = profile[key] + for i in range(len(vals[0])): + if vals[0][i] != vals[2][i]: + axes.append(i) + if len(axes) > 0: + dynamic_axes[key] = axes + return dynamic_axes + + +def cuassert(cuda_ret): + """ + Error reporting method for CUDA calls. + Args: + cuda_ret: CUDA return code. + """ + err = cuda_ret[0] + if err != 0: + raise RuntimeError(f"CUDA ERROR: {err}") + if len(cuda_ret) > 1: + return cuda_ret[1] + return None + + +class ShapeError(Exception): + """ + Exception class to report errors from setting TRT plan input shapes + """ + + pass + + +class TRTEngine: + """ + An auxiliary class to implement running of TRT optimized engines + + """ + + def __init__(self, plan_path, logger=None): + """ + Loads serialized engine, creates execution context and activates it + Args: + plan_path: path to serialized TRT engine. 
+ logger: optional logger object + """ + self.plan_path = plan_path + self.logger = logger or get_logger("trt_compile") + self.logger.info(f"Loading TensorRT engine: {self.plan_path}") + self.engine = engine_from_bytes(bytes_from_path(self.plan_path)) + self.tensors = OrderedDict() + self.cuda_graph_instance = None # cuda graph + self.context = self.engine.create_execution_context() + self.input_names = [] + self.output_names = [] + self.dtypes = [] + self.cur_profile = 0 + dtype_dict = trt_to_torch_dtype_dict() + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT: + self.input_names.append(binding) + elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT: + self.output_names.append(binding) + dtype = dtype_dict[self.engine.get_tensor_dtype(binding)] + self.dtypes.append(dtype) + + def allocate_buffers(self, device): + """ + Allocates outputs to run TRT engine + Args: + device: GPU device to allocate memory on + """ + ctx = self.context + + for i, binding in enumerate(self.output_names): + shape = list(ctx.get_tensor_shape(binding)) + if binding not in self.tensors or list(self.tensors[binding].shape) != shape: + t = torch.empty(shape, dtype=self.dtypes[i], device=device).contiguous() + self.tensors[binding] = t + ctx.set_tensor_address(binding, t.data_ptr()) + + def set_inputs(self, feed_dict, stream): + """ + Sets input bindings for TRT engine according to feed_dict + Args: + feed_dict: a dictionary [str->Tensor] + stream: CUDA stream to use + """ + e = self.engine + ctx = self.context + + last_profile = self.cur_profile + + def try_set_inputs(): + for binding, t in feed_dict.items(): + if t is not None: + t = t.contiguous() + shape = t.shape + ctx.set_input_shape(binding, shape) + ctx.set_tensor_address(binding, t.data_ptr()) + + while True: + try: + try_set_inputs() + break + except ShapeError: + next_profile = (self.cur_profile + 1) % e.num_optimization_profiles + if next_profile == last_profile: + raise + self.cur_profile = next_profile + ctx.set_optimization_profile_async(self.cur_profile, stream) + + left = ctx.infer_shapes() + assert len(left) == 0 + + def infer(self, stream, use_cuda_graph=False): + """ + Runs TRT engine. + Args: + stream: CUDA stream to run on + use_cuda_graph: use CUDA graph. Note: requires all inputs to be the same GPU memory between calls. 
+ """ + if use_cuda_graph: + if self.cuda_graph_instance is not None: + cuassert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + cuassert(cudart.cudaStreamSynchronize(stream)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + # capture cuda graph + cuassert( + cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal) + ) + self.context.execute_async_v3(stream) + graph = cuassert(cudart.cudaStreamEndCapture(stream)) + self.cuda_graph_instance = cuassert(cudart.cudaGraphInstantiate(graph, 0)) + self.logger.info("CUDA Graph captured!") + else: + noerror = self.context.execute_async_v3(stream) + cuassert(cudart.cudaStreamSynchronize(stream)) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +class TrtCompiler: + """ + This class implements: + - TRT lazy persistent export + - Running TRT with optional fallback to Torch + (for TRT engines with limited profiles) + """ + + def __init__( + self, + model, + plan_path, + precision="fp16", + method="onnx", + input_names=None, + output_names=None, + export_args=None, + build_args=None, + input_profiles=None, + dynamic_batchsize=None, + use_cuda_graph=False, + timestamp=None, + fallback=False, + logger=None, + ): + """ + Initialization method: + Tries to load persistent serialized TRT engine + Saves its arguments for lazy TRT build on first forward() call + Args: + model: Model to "wrap". + plan_path : Path where to save persistent serialized TRT engine. + precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. + method: One of 'onnx'|'torch_trt'. + Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option. + 'torch_trt' may not work for some nets. Also AMP must be turned off for it to work. + input_names: Optional list of input names. If None, will be read from the function signature. + output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. + export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. + build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. + input_profiles: Optional list of profiles for TRT builder and ONNX export. + Each profile is a map of the form : {"input id" : [min_shape, opt_shape, max_shape], ...}. + dynamic_batchsize: A sequence with three elements to define the batch size range of the input for the model to be + converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. + [note]: If neither input_profiles nor dynamic_batchsize specified, static shapes will be used to build TRT engine. + use_cuda_graph: Use CUDA Graph for inference. Note: all inputs have to be the same GPU memory between calls! + timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes). + fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile). 
+ """ + + method_vals = ["onnx", "torch_trt"] + if method not in method_vals: + raise ValueError(f"trt_compile(): 'method' should be one of {method_vals}, got: {method}.") + precision_vals = ["fp32", "tf32", "fp16", "bf16"] + if precision not in precision_vals: + raise ValueError(f"trt_compile(): 'precision' should be one of {precision_vals}, got: {precision}.") + + self.plan_path = plan_path + self.precision = precision + self.method = method + self.return_dict = output_names is not None + self.output_names = output_names or [] + self.profiles = input_profiles or [] + self.dynamic_batchsize = dynamic_batchsize + self.export_args = export_args or {} + self.build_args = build_args or {} + self.engine: TRTEngine | None = None + self.use_cuda_graph = use_cuda_graph + self.fallback = fallback + self.disabled = False + + self.logger = logger or get_logger("trt_compile") + + # Normally we read input_names from forward() but can be overridden + if input_names is None: + argspec = inspect.getfullargspec(model.forward) + input_names = argspec.args[1:] + self.input_names = input_names + self.old_forward = model.forward + + # Force engine rebuild if older than the timestamp + if timestamp is not None and os.path.exists(self.plan_path) and os.path.getmtime(self.plan_path) < timestamp: + os.remove(self.plan_path) + + def _inputs_to_dict(self, input_example): + trt_inputs = {} + for i, inp in enumerate(input_example): + input_name = self.input_names[i] + trt_inputs[input_name] = inp + return trt_inputs + + def _load_engine(self): + """ + Loads TRT plan from disk and activates its execution context. + """ + try: + self.engine = TRTEngine(self.plan_path, self.logger) + self.input_names = self.engine.input_names + except Exception as e: + self.logger.debug(f"Exception while loading the engine:\n{e}") + + def forward(self, model, argv, kwargs): + """ + Main forward method: + Builds TRT engine if not available yet. 
+ Tries to run TRT engine + If exception thrown and self.callback==True: falls back to original Pytorch + + Args: Passing through whatever args wrapped module's forward() has + Returns: Passing through wrapped module's forward() return value(s) + + """ + if self.engine is None and not self.disabled: + # Restore original forward for export + new_forward = model.forward + model.forward = self.old_forward + try: + self._load_engine() + if self.engine is None: + build_args = kwargs.copy() + if len(argv) > 0: + build_args.update(self._inputs_to_dict(argv)) + self._build_and_save(model, build_args) + # This will reassign input_names from the engine + self._load_engine() + assert self.engine is not None + except Exception as e: + if self.fallback: + self.logger.info(f"Failed to build engine: {e}") + self.disabled = True + else: + raise e + if not self.disabled and not self.fallback: + # Delete all parameters + for param in model.parameters(): + del param + # Call empty_cache to release GPU memory + torch.cuda.empty_cache() + model.forward = new_forward + # Run the engine + try: + if len(argv) > 0: + kwargs.update(self._inputs_to_dict(argv)) + argv = () + + if self.engine is not None: + # forward_trt is not thread safe as we do not use per-thread execution contexts + with lock_sm: + device = torch.cuda.current_device() + stream = torch.cuda.Stream(device=device) + self.engine.set_inputs(kwargs, stream.cuda_stream) + self.engine.allocate_buffers(device=device) + # Need this to synchronize with Torch stream + stream.wait_stream(torch.cuda.current_stream()) + ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph) + # if output_names is not None, return dictionary + if not self.return_dict: + ret = list(ret.values()) + if len(ret) == 1: + ret = ret[0] + return ret + except Exception as e: + if model is not None: + self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...") + else: + raise e + return self.old_forward(*argv, **kwargs) + + def _onnx_to_trt(self, onnx_path): + """ + Builds TRT engine from ONNX file at onnx_path and saves to self.plan_path + """ + + profiles = [] + if self.profiles: + for input_profile in self.profiles: + if isinstance(input_profile, Profile): + profiles.append(input_profile) + else: + p = Profile() + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + profiles.append(p) + + build_args = self.build_args.copy() + build_args["tf32"] = self.precision != "fp32" + if self.precision == "fp16": + build_args["fp16"] = True + elif self.precision == "bf16": + build_args["bf16"] = True + + self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}") + network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) + return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) + + def _build_and_save(self, model, input_example): + """ + If TRT engine is not ready, exports model to ONNX, + builds TRT engine and saves serialized TRT engine to the disk. 
+ Args: + input_example: passed to onnx.export() + """ + + if self.engine is not None: + return + + export_args = self.export_args + + add_casts_around_norms(model) + + if self.method == "torch_trt": + enabled_precisions = [torch.float32] + if self.precision == "fp16": + enabled_precisions.append(torch.float16) + elif self.precision == "bf16": + enabled_precisions.append(torch.bfloat16) + inputs = list(input_example.values()) + ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True) + + def get_torch_trt_input(input_shape, dynamic_batchsize): + min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) + return torch_tensorrt.Input( + min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape + ) + + tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs] + engine_bytes = torch_tensorrt.convert_method_to_trt_engine( + ir_model, + "forward", + inputs=tt_inputs, + ir="torchscript", + enabled_precisions=enabled_precisions, + **export_args, + ) + else: + dbs = self.dynamic_batchsize + if dbs: + if len(self.profiles) > 0: + raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!") + if len(dbs) != 3: + raise ValueError("dynamic_batchsize has to have len ==3 ") + profiles = {} + for id, val in input_example.items(): + sh = val.shape[1:] + profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] + self.profiles = [profiles] + + if len(self.profiles) > 0: + export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)}) + + # Use temporary directory for easy cleanup in case of external weights + with tempfile.TemporaryDirectory() as tmpdir: + onnx_path = Path(tmpdir) / "model.onnx" + self.logger.info( + f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}" + ) + convert_to_onnx( + model, + input_example, + filename=str(onnx_path), + input_names=self.input_names, + output_names=self.output_names, + **export_args, + ) + self.logger.info("Export to ONNX successful.") + engine_bytes = self._onnx_to_trt(str(onnx_path)) + + open(self.plan_path, "wb").write(engine_bytes) + + +def trt_forward(self, *argv, **kwargs): + """ + Patch function to replace original model's forward() with. + Redirects to TrtCompiler.forward() + """ + return self._trt_compiler.forward(self, argv, kwargs) + + +def trt_compile( + model: torch.nn.Module, + base_path: str, + args: Dict[str, Any] | None = None, + submodule: Union[str, List[str]] | None = None, + logger: Any | None = None, +) -> torch.nn.Module: + """ + Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook. + Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x + Args: + model: module to patch with TrtCompiler object. + base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path. + dirname(base_path) must exist, base_path does not have to. + If base_path does point to existing file (e.g. associated checkpoint), + that file becomes a dependency - its mtime is added to args["timestamp"]. + args: Optional dict : unpacked and passed to TrtCompiler() - see TrtCompiler above for details. + submodule: Optional hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder'] + If None, TrtCompiler patch is applied to the whole model. + Otherwise, submodule (or list of) is being patched. + logger: Optional logger for diagnostics. + Returns: + Always returns same model passed in as argument. 
This is for ease of use in configs. + """ + + default_args: Dict[str, Any] = { + "method": "onnx", + "precision": "fp16", + "build_args": {"builder_optimization_level": 5, "precision_constraints": "obey"}, + } + + default_args.update(args or {}) + args = default_args + + if trt_imported and polygraphy_imported and torch.cuda.is_available(): + # if "path" filename point to existing file (e.g. checkpoint) + # it's also treated as dependency + if os.path.exists(base_path): + timestamp = int(os.path.getmtime(base_path)) + if "timestamp" in args: + timestamp = max(int(args["timestamp"]), timestamp) + args["timestamp"] = timestamp + + def wrap(model, path): + wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) + model._trt_compiler = wrapper + model.forward = MethodType(trt_forward, model) + + def find_sub(parent, submodule): + idx = submodule.find(".") + # if there is "." in name, call recursively + if idx != -1: + parent_name = submodule[:idx] + parent = getattr(parent, parent_name) + submodule = submodule[idx + 1 :] + return find_sub(parent, submodule) + return parent, submodule + + if submodule is not None: + if isinstance(submodule, str): + submodule = [submodule] + for s in submodule: + parent, sub = find_sub(model, s) + wrap(getattr(parent, sub), base_path + "." + s) + else: + wrap(model, base_path) + else: + logger = logger or get_logger("trt_compile") + logger.warning("TensorRT and/or polygraphy packages are not available! trt_compile() has no effect.") + + return model diff --git a/monai/networks/utils.py b/monai/networks/utils.py index f301c2dd5c..d0150b4e5b 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -36,6 +36,8 @@ onnx, _ = optional_import("onnx") onnxreference, _ = optional_import("onnx.reference") onnxruntime, _ = optional_import("onnxruntime") +polygraphy, polygraphy_imported = optional_import("polygraphy") +torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0") __all__ = [ "one_hot", @@ -61,6 +63,7 @@ "look_up_named_module", "set_named_module", "has_nvfuser_instance_norm", + "get_profile_shapes", ] logger = get_logger(module_name=__name__) @@ -68,6 +71,26 @@ _has_nvfuser = None +def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None): + """ + Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize. + """ + + def scale_batch_size(input_shape: Sequence[int], scale_num: int): + scale_shape = [*input_shape] + scale_shape[0] = scale_num + return scale_shape + + # Use the dynamic batchsize range to generate the min, opt and max model input shape + if dynamic_batchsize: + min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0]) + opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1]) + max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2]) + else: + min_input_shape = opt_input_shape = max_input_shape = input_shape + return min_input_shape, opt_input_shape, max_input_shape + + def has_nvfuser_instance_norm(): """whether the current environment has InstanceNorm3dNVFuser https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16 @@ -606,6 +629,9 @@ def convert_to_onnx( rtol: float = 1e-4, atol: float = 0.0, use_trace: bool = True, + do_constant_folding: bool = True, + constant_size_threshold: int = 16 * 1024 * 1024 * 1024, + dynamo=False, **kwargs, ): """ @@ -632,7 +658,10 @@ def convert_to_onnx( rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model. 
atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model. use_trace: whether to use `torch.jit.trace` to export the torchscript model. - kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details: + do_constant_folding: passed to onnx.export(). If True, extra polygraphy folding pass is done. + constant_size_threshold: passed to polygrapy conatant forling, default = 16M + kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export() + else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details: https://pytorch.org/docs/master/generated/torch.jit.script.html. """ @@ -642,6 +671,7 @@ def convert_to_onnx( if use_trace: # let torch.onnx.export to trace the model. mode_to_export = model + torch_versioned_kwargs = kwargs else: if not pytorch_after(1, 10): if "example_outputs" not in kwargs: @@ -654,32 +684,37 @@ def convert_to_onnx( del kwargs["example_outputs"] mode_to_export = torch.jit.script(model, **kwargs) + if torch.is_tensor(inputs) or isinstance(inputs, dict): + onnx_inputs = (inputs,) + else: + onnx_inputs = tuple(inputs) + if filename is None: f = io.BytesIO() - torch.onnx.export( - mode_to_export, - tuple(inputs), - f=f, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - opset_version=opset_version, - **torch_versioned_kwargs, - ) + else: + f = filename + + torch.onnx.export( + mode_to_export, + onnx_inputs, + f=f, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=opset_version, + do_constant_folding=do_constant_folding, + **torch_versioned_kwargs, + ) + if filename is None: onnx_model = onnx.load_model_from_string(f.getvalue()) else: - torch.onnx.export( - mode_to_export, - tuple(inputs), - f=filename, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - opset_version=opset_version, - **torch_versioned_kwargs, - ) onnx_model = onnx.load(filename) + if do_constant_folding and polygraphy_imported: + from polygraphy.backend.onnx.loader import fold_constants + + fold_constants(onnx_model, size_threshold=constant_size_threshold) + if verify: if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -814,7 +849,6 @@ def _onnx_trt_compile( """ trt, _ = optional_import("tensorrt", "8.5.3") - torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0") input_shapes = (min_shape, opt_shape, max_shape) # default to an empty list to fit the `torch_tensorrt.ts.embed_engine_in_new_module` function. @@ -851,7 +885,7 @@ def _onnx_trt_compile( # wrap the serialized TensorRT engine back to a TorchScript module. trt_model = torch_tensorrt.ts.embed_engine_in_new_module( f.getvalue(), - device=torch.device(f"cuda:{device}"), + device=torch_tensorrt.Device(f"cuda:{device}"), input_binding_names=input_names, output_binding_names=output_names, ) @@ -916,8 +950,6 @@ def convert_to_trt( to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py. 
""" - torch_tensorrt, _ = optional_import("torch_tensorrt", version="1.4.0") - if not torch.cuda.is_available(): raise Exception("Cannot find any GPU devices.") @@ -935,23 +967,9 @@ def convert_to_trt( convert_precision = torch.float32 if precision == "fp32" else torch.half inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)] - def scale_batch_size(input_shape: Sequence[int], scale_num: int): - scale_shape = [*input_shape] - scale_shape[0] *= scale_num - return scale_shape - - # Use the dynamic batchsize range to generate the min, opt and max model input shape - if dynamic_batchsize: - min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0]) - opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1]) - max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2]) - else: - min_input_shape = opt_input_shape = max_input_shape = input_shape - # convert the torch model to a TorchScript model on target device model = model.eval().to(target_device) - ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace) - ir_model.eval() + min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) if use_onnx: # set the batch dim as dynamic @@ -960,7 +978,6 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): ir_model = convert_to_onnx( model, inputs, onnx_input_names, onnx_output_names, use_trace=use_trace, dynamic_axes=dynamic_axes ) - # convert the model through the ONNX-TensorRT way trt_model = _onnx_trt_compile( ir_model, @@ -973,6 +990,8 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): output_names=onnx_output_names, ) else: + ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace) + ir_model.eval() # convert the model through the Torch-TensorRT way ir_model.to(target_device) with torch.no_grad(): @@ -1189,3 +1208,168 @@ def forward(self, x): if dtype == self.initial_type: x = x.to(self.initial_type) return x + + +def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32): + """ + Utility function to cast a single tensor from from_dtype to to_dtype + """ + return x.to(dtype=to_dtype) if x.dtype == from_dtype else x + + +def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32): + """ + Utility function to cast all tensors in a tuple from from_dtype to to_dtype + """ + if isinstance(x, torch.Tensor): + return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype) + else: + if isinstance(x, dict): + new_dict = {} + for k in x.keys(): + new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype) + return new_dict + elif isinstance(x, tuple): + return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x) + + +class CastToFloat(torch.nn.Module): + """ + Class used to add autocast protection for ONNX export + for forward methods with single return vaue + """ + + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, x): + dtype = x.dtype + with torch.amp.autocast("cuda", enabled=False): + ret = self.mod.forward(x.to(torch.float32)).to(dtype) + return ret + + +class CastToFloatAll(torch.nn.Module): + """ + Class used to add autocast protection for ONNX export + for forward methods with multiple return values + """ + + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, *args): + from_dtype = args[0].dtype + with torch.amp.autocast("cuda", enabled=False): + ret = 
self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) + return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) + + +def wrap_module(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: + """ + Generic function generator to replace base_t module with dest_t wrapper. + Args: + base_t : module type to replace + dest_t : destination module type + Returns: + swap function to replace base_t module with dest_t + """ + + def expansion_fn(mod: nn.Module) -> nn.Module | None: + out = dest_t(mod) + return out + + return expansion_fn + + +def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]: + """ + Generic function generator to replace base_t module with dest_t. + base_t and dest_t should have same atrributes. No weights are copied. + Args: + base_t : module type to replace + dest_t : destination module type + Returns: + swap function to replace base_t module with dest_t + """ + + def expansion_fn(mod: nn.Module) -> nn.Module | None: + if not isinstance(mod, base_t): + return None + args = [getattr(mod, name, None) for name in mod.__constants__] + out = dest_t(*args) + return out + + return expansion_fn + + +def _swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module: + """ + This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows + for swapping nested modules through arbitrary levels if children + + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. + + """ + for path, new_mod in mapping.items(): + expanded_path = path.split(".") + parent_mod = model + for sub_path in expanded_path[:-1]: + submod = parent_mod._modules[sub_path] + if submod is None: + break + else: + parent_mod = submod + parent_mod._modules[expanded_path[-1]] = new_mod + + return model + + +def replace_modules_by_type( + model: nn.Module, expansions: dict[str, Callable[[nn.Module], nn.Module | None]] +) -> nn.Module: + """ + Top-level function to replace modules in model, specified by class name with a desired replacement. + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. + Args: + model : top level module + expansions : replacement dictionary: module class name -> replacement function generator + Returns: + model, possibly modified in-place + """ + mapping: dict[str, nn.Module] = {} + for name, m in model.named_modules(): + m_type = type(m).__name__ + if m_type in expansions: + # print (f"Found {m_type} in expansions ...") + swapped = expansions[m_type](m) + if swapped: + mapping[name] = swapped + + print(f"Swapped {len(mapping)} modules") + _swap_modules(model, mapping) + return model + + +def add_casts_around_norms(model: nn.Module) -> nn.Module: + """ + Top-level function to add cast wrappers around modules known to cause issues for FP16/autocast ONNX export + NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. 
+    Args:
+        model : top level module
+    Returns:
+        model, possibly modified in-place
+    """
+    print("Adding casts around norms...")
+    cast_replacements = {
+        "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
+        "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
+        "BatchNorm3d": wrap_module(nn.BatchNorm3d, CastToFloat),
+        "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
+        "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat),
+        "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat),
+    }
+    replace_modules_by_type(model, cast_replacements)
+    return model
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index f37016e63f..2cdd965c91 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -396,6 +396,8 @@
 from .spatial.array import (
     Affine,
     AffineGrid,
+    ConvertBoxToPoints,
+    ConvertPointsToBoxes,
     Flip,
     GridDistortion,
     GridPatch,
@@ -427,6 +429,12 @@
     Affined,
     AffineD,
     AffineDict,
+    ConvertBoxToPointsd,
+    ConvertBoxToPointsD,
+    ConvertBoxToPointsDict,
+    ConvertPointsToBoxesd,
+    ConvertPointsToBoxesD,
+    ConvertPointsToBoxesDict,
     Flipd,
     FlipD,
     FlipDict,
@@ -503,6 +511,7 @@
 from .utility.array import (
     AddCoordinateChannels,
     AddExtremePointsChannel,
+    ApplyTransformToPoints,
     AsChannelLast,
     CastToType,
     ClassesToIndices,
@@ -542,6 +551,9 @@
     AddExtremePointsChanneld,
     AddExtremePointsChannelD,
     AddExtremePointsChannelDict,
+    ApplyTransformToPointsd,
+    ApplyTransformToPointsD,
+    ApplyTransformToPointsDict,
     AsChannelLastd,
     AsChannelLastD,
     AsChannelLastDict,
diff --git a/monai/transforms/adaptors.py b/monai/transforms/adaptors.py
index f5f1a4fc18..5a0c24c7f6 100644
--- a/monai/transforms/adaptors.py
+++ b/monai/transforms/adaptors.py
@@ -125,12 +125,9 @@ def __call__(self, img, seg):

 from typing import Callable

-from monai.utils import export as _monai_export
-
 __all__ = ["adaptor", "apply_alias", "to_kwargs", "FunctionSignature"]


-@_monai_export("monai.transforms")
 def adaptor(function, outputs, inputs=None):

     def must_be_types_or_none(variable_name, variable, types):
@@ -215,7 +212,6 @@ def _inner(ditems):
     return _inner


-@_monai_export("monai.transforms")
 def apply_alias(fn, name_map):

     def _inner(data):
@@ -236,7 +232,6 @@ def _inner(data):
     return _inner


-@_monai_export("monai.transforms")
 def to_kwargs(fn):

     def _inner(data):
diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py
index 3b813809e4..20000c52c4 100644
--- a/monai/transforms/intensity/array.py
+++ b/monai/transforms/intensity/array.py
@@ -1411,7 +1411,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         else:
             img_t = self._normalize(img=img_t)

-        return convert_to_dst_type(img_t, dst=img)[0]
+        return convert_to_dst_type(img_t, dst=img, dtype=self.dtype)[0]


 class MaskIntensity(Transform):
diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index 3739a83e71..6e39fb2e19 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -25,6 +25,7 @@

 from monai.config import USE_COMPILED, DtypeLike
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import BoxMode, StandardMode
 from monai.data.meta_obj import get_track_meta, set_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
@@ -34,6 +35,8 @@
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.spatial.functional import (
     affine_func,
convert_box_to_points, + convert_points_to_box, flip, orientation, resize, @@ -3544,3 +3547,44 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: else: return img + + +class ConvertBoxToPoints(Transform): + """ + Converts an axis-aligned bounding box to points. It can automatically convert the boxes to the points based on the box mode. + Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2] for 3D for each box. + Return shape will be (N, 4, 2) for 2D or (N, 8, 3) for 3D. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None: + """ + Args: + mode: the mode of the box, can be a string, a BoxMode instance or a BoxMode class. Defaults to StandardMode. + """ + super().__init__() + self.mode = StandardMode if mode is None else mode + + def __call__(self, data: Any): + data = convert_to_tensor(data, track_meta=get_track_meta()) + points = convert_box_to_points(data, mode=self.mode) + return convert_to_dst_type(points, data)[0] + + +class ConvertPointsToBoxes(Transform): + """ + Converts points to an axis-aligned bounding box. + Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of a 3D cuboid or + (N, 4, 2) for the 4 corners of a 2D rectangle. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self) -> None: + super().__init__() + + def __call__(self, data: Any): + data = convert_to_tensor(data, track_meta=get_track_meta()) + box = convert_points_to_box(data) + return convert_to_dst_type(box, data)[0] diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 01fadcfb69..2b80034a07 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -26,6 +26,7 @@ from monai.config import DtypeLike, KeysCollection, SequenceStr from monai.config.type_definitions import NdarrayOrTensor +from monai.data.box_utils import BoxMode, StandardMode from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.networks.layers.simplelayers import GaussianFilter @@ -33,6 +34,8 @@ from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.array import ( Affine, + ConvertBoxToPoints, + ConvertPointsToBoxes, Flip, GridDistortion, GridPatch, @@ -2585,6 +2588,7 @@ def set_random_state( self, seed: int | None = None, state: np.random.RandomState | None = None ) -> RandSimulateLowResolutiond: super().set_random_state(seed, state) + self.sim_lowres_tfm.set_random_state(seed, state) return self def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: @@ -2611,6 +2615,61 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class ConvertBoxToPointsd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`. + """ + + backend = ConvertBoxToPoints.backend + + def __init__( + self, + keys: KeysCollection, + point_key="points", + mode: str | BoxMode | type[BoxMode] | None = StandardMode, + allow_missing_keys: bool = False, + ): + """ + Args: + keys: keys of the corresponding items to be transformed. + point_key: key to store the point data. + mode: the mode of the input boxes. Defaults to StandardMode. + allow_missing_keys: don't raise exception if key is missing. 
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.point_key = point_key
+        self.converter = ConvertBoxToPoints(mode=mode)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[self.point_key] = self.converter(d[key])
+        return d
+
+
+class ConvertPointsToBoxesd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertPointsToBoxes`.
+    """
+
+    backend = ConvertPointsToBoxes.backend
+
+    def __init__(self, keys: KeysCollection, box_key="box", allow_missing_keys: bool = False):
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            box_key: key to store the box data.
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.box_key = box_key
+        self.converter = ConvertPointsToBoxes()
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[self.box_key] = self.converter(d[key])
+        return d
+
+
 SpatialResampleD = SpatialResampleDict = SpatialResampled
 ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
 SpacingD = SpacingDict = Spacingd
@@ -2635,3 +2694,5 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
 GridPatchD = GridPatchDict = GridPatchd
 RandGridPatchD = RandGridPatchDict = RandGridPatchd
 RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
+ConvertBoxToPointsD = ConvertBoxToPointsDict = ConvertBoxToPointsd
+ConvertPointsToBoxesD = ConvertPointsToBoxesDict = ConvertPointsToBoxesd
diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py
index 22726f06a5..b693e7d023 100644
--- a/monai/transforms/spatial/functional.py
+++ b/monai/transforms/spatial/functional.py
@@ -24,6 +24,7 @@
 import monai
 from monai.config import USE_COMPILED
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import get_boxmode
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
@@ -32,7 +33,7 @@
 from monai.transforms.intensity.array import GaussianSmooth
 from monai.transforms.inverse import TraceableTransform
 from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
-from monai.transforms.utils_pytorch_numpy_unification import allclose
+from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack
 from monai.utils import (
     LazyAttr,
     TraceKeys,
@@ -610,3 +611,71 @@ def affine_func(
     out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
     out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
     return out if image_only else (out, affine)
+
+
+def convert_box_to_points(bbox, mode):
+    """
+    Converts an axis-aligned bounding box to points.
+
+    Args:
+        bbox: Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or
+            [x1, y1, z1, x2, y2, z2] for 3D for each box.
+        mode: The mode specifying how to interpret the bounding box.
+
+    Returns:
+        A sequence of points representing the corners of each box, of shape (N, 4, 2) for 2D
+        or (N, 8, 3) for 3D.
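+
+    A minimal sketch (illustrative values only, ``xyxy`` boxes in 2D)::
+
+        pts = convert_box_to_points(torch.tensor([[0.0, 0.0, 2.0, 3.0]]), mode="xyxy")
+        # pts[0] -> [[0, 0], [2, 0], [2, 3], [0, 3]]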
+ """ + + mode = get_boxmode(mode) + + points_list = [] + for _num in range(bbox.shape[0]): + corners = mode.boxes_to_corners(bbox[_num : _num + 1]) + if len(corners) == 4: + points_list.append( + concatenate( + [ + concatenate([corners[0], corners[1]], axis=1), + concatenate([corners[2], corners[1]], axis=1), + concatenate([corners[2], corners[3]], axis=1), + concatenate([corners[0], corners[3]], axis=1), + ], + axis=0, + ) + ) + else: + points_list.append( + concatenate( + [ + concatenate([corners[0], corners[1], corners[2]], axis=1), + concatenate([corners[3], corners[1], corners[2]], axis=1), + concatenate([corners[3], corners[4], corners[2]], axis=1), + concatenate([corners[0], corners[4], corners[2]], axis=1), + concatenate([corners[0], corners[1], corners[5]], axis=1), + concatenate([corners[3], corners[1], corners[5]], axis=1), + concatenate([corners[3], corners[4], corners[5]], axis=1), + concatenate([corners[0], corners[4], corners[5]], axis=1), + ], + axis=0, + ) + ) + + return stack(points_list, dim=0) + + +def convert_points_to_box(points): + """ + Converts points to an axis-aligned bounding box. + + Args: + points: Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of + a 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle. + """ + from monai.transforms.utils_pytorch_numpy_unification import max, min + + mins = min(points, dim=1) + maxs = max(points, dim=1) + # Concatenate the min and max values to get the bounding boxes + bboxes = concatenate([mins, maxs], axis=1) + + return bboxes diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 5dfbcb0e91..72dd189009 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -31,7 +31,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor -from monai.data.utils import is_no_channel, no_collation +from monai.data.utils import is_no_channel, no_collation, orientation_ras_lps from monai.networks.layers.simplelayers import ( ApplyFilter, EllipticalFilter, @@ -42,16 +42,17 @@ SharpenFilter, median_filter, ) -from monai.transforms.inverse import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform, TraceableTransform from monai.transforms.traits import MultiSampleTrait from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform from monai.transforms.utils import ( + apply_affine_to_points, extreme_points_to_image, get_extreme_points, map_binary_to_indices, map_classes_to_indices, ) -from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices +from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, linalg_inv, moveaxis, unravel_indices from monai.utils import ( MetaKeys, TraceKeys, @@ -66,7 +67,7 @@ ) from monai.utils.enums import TransformBackends from monai.utils.misc import is_module_ver_at_least -from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype +from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype PILImageImage, has_pil = optional_import("PIL.Image", name="Image") pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray") @@ -106,6 +107,7 @@ "ToCupy", "ImageFilter", "RandImageFilter", + "ApplyTransformToPoints", ] @@ -654,6 +656,7 @@ def __init__( data_shape: bool = True, value_range: bool = True, 
data_value: bool = False, + meta_info: bool = False, additional_info: Callable | None = None, name: str = "DataStats", ) -> None: @@ -665,6 +668,7 @@ def __init__( value_range: whether to show the value range of input data. data_value: whether to show the raw value of input data. a typical example is to print some properties of Nifti image: affine, pixdim, etc. + meta_info: whether to show the data of MetaTensor. additional_info: user can define callable function to extract additional info from input data. name: identifier of `logging.logger` to use, defaulting to "DataStats". @@ -679,6 +683,7 @@ def __init__( self.data_shape = data_shape self.value_range = value_range self.data_value = data_value + self.meta_info = meta_info if additional_info is not None and not callable(additional_info): raise TypeError(f"additional_info must be None or callable but is {type(additional_info).__name__}.") self.additional_info = additional_info @@ -705,6 +710,7 @@ def __call__( data_shape: bool | None = None, value_range: bool | None = None, data_value: bool | None = None, + meta_info: bool | None = None, additional_info: Callable | None = None, ) -> NdarrayOrTensor: """ @@ -725,6 +731,9 @@ def __call__( lines.append(f"Value range: (not a PyTorch or Numpy array, type: {type(img)})") if self.data_value if data_value is None else data_value: lines.append(f"Value: {img}") + if self.meta_info if meta_info is None else meta_info: + metadata = getattr(img, "meta", "(input is not a MetaTensor)") + lines.append(f"Meta info: {repr(metadata)}") additional_info = self.additional_info if additional_info is None else additional_info if additional_info is not None: lines.append(f"Additional info: {additional_info(img)}") @@ -1715,3 +1724,143 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Mapping | None = None) -> Nd if self._do_transform: img = self.filter(img) return img + + +class ApplyTransformToPoints(InvertibleTransform, Transform): + """ + Transform points between image coordinates and world coordinates. + The input coordinates are assumed to be in the shape (C, N, 2 or 3), where C represents the number of channels + and N denotes the number of points. It will return a tensor with the same shape as the input. + + Args: + dtype: The desired data type for the output. + affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates + from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary + Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when + applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly. + The matrix is always converted to float64 for computation, which can be computationally + expensive when applied to a large number of points. + If None, will try to use the affine matrix from the input data. + invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``. + Typically, the affine matrix is derived from an image and represents its location in world space, + while the points are in world coordinates. A value of ``True`` represents transforming these + world space coordinates to the image's coordinate space, and ``False`` the inverse of this operation. + affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system + or you're using `ITKReader` with `affine_lps_to_ras=True`. 
+            This ensures the correct application of the affine transformation between LPS (left-posterior-superior)
+            and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine
+            matrix are in the same coordinate system.
+
+    Use Cases:
+        - Transforming points between world space and image space, and vice versa.
+        - Automatically handling inverse transformations between image space and world space.
+        - If points have an existing affine transformation, the class computes and
+          applies the required delta affine transformation.
+
+    """
+
+    def __init__(
+        self,
+        dtype: DtypeLike | torch.dtype | None = None,
+        affine: torch.Tensor | None = None,
+        invert_affine: bool = True,
+        affine_lps_to_ras: bool = False,
+    ) -> None:
+        self.dtype = dtype
+        self.affine = affine
+        self.invert_affine = invert_affine
+        self.affine_lps_to_ras = affine_lps_to_ras
+
+    def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tensor | None = None) -> torch.Tensor:
+        """
+        Compute the final affine transformation matrix to apply to the point data.
+
+        Args:
+            affine: 3x3 or 4x4 affine transformation matrix to be applied to the points.
+            applied_affine: 3x3 or 4x4 affine transformation matrix that has already been applied to the points.
+
+        Returns:
+            Final affine transformation matrix.
+        """
+
+        affine = convert_data_type(affine, dtype=torch.float64)[0]
+
+        if self.affine_lps_to_ras:
+            affine = orientation_ras_lps(affine)
+
+        if self.invert_affine:
+            affine = linalg_inv(affine)
+        if applied_affine is not None:
+            affine = affine @ applied_affine
+
+        return affine
+
+    def transform_coordinates(
+        self, data: torch.Tensor, affine: torch.Tensor | None = None
+    ) -> tuple[torch.Tensor, dict]:
+        """
+        Transform coordinates using an affine transformation matrix.
+
+        Args:
+            data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),
+                where C represents the number of channels and N denotes the number of points.
+            affine: 3x3 or 4x4 affine transformation matrix. The matrix is always converted to float64 for computation,
+                which can be computationally expensive when applied to a large number of points.
+
+        Returns:
+            The transformed coordinates and a dictionary of metadata describing the applied transform.
+        """
+        data = convert_to_tensor(data, track_meta=get_track_meta())
+        if affine is None and self.invert_affine:
+            raise ValueError("affine must be provided when invert_affine is True.")
+        # applied_affine is the affine transformation matrix that has already been applied to the point data
+        applied_affine: torch.Tensor | None = getattr(data, "affine", None)
+        affine = applied_affine if affine is None else affine
+        if affine is None:
+            raise ValueError("affine must be provided if data does not have an affine matrix.")
+
+        final_affine = self._compute_final_affine(affine, applied_affine)
+        out = apply_affine_to_points(data, final_affine, dtype=self.dtype)
+
+        extra_info = {
+            "invert_affine": self.invert_affine,
+            "dtype": get_dtype_string(self.dtype),
+            "image_affine": affine,
+            "affine_lps_to_ras": self.affine_lps_to_ras,
+        }
+
+        xform = orientation_ras_lps(linalg_inv(final_affine)) if self.affine_lps_to_ras else linalg_inv(final_affine)
+        meta_info = TraceableTransform.track_transform_meta(
+            data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info()
+        )
+
+        return out, meta_info
+
+    def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None):
+        """
+        Args:
+            data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),
+                where C represents the number of channels and N denotes the number of points.
+ affine: A 3x3 or 4x4 affine transformation matrix, this argument will take precedence over ``self.affine``. + """ + if data.ndim != 3 or data.shape[-1] not in (2, 3): + raise ValueError(f"data should be in shape (C, N, 2 or 3), got {data.shape}.") + affine = self.affine if affine is None else affine + if affine is not None and affine.shape not in ((3, 3), (4, 4)): + raise ValueError(f"affine should be in shape (3, 3) or (4, 4), got {affine.shape}.") + + out, meta_info = self.transform_coordinates(data, affine) + + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + def inverse(self, data: torch.Tensor) -> torch.Tensor: + transform = self.pop_transform(data) + inverse_transform = ApplyTransformToPoints( + dtype=transform[TraceKeys.EXTRA_INFO]["dtype"], + invert_affine=not transform[TraceKeys.EXTRA_INFO]["invert_affine"], + affine_lps_to_ras=transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"], + ) + with inverse_transform.trace_transform(False): + data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"]) + + return data diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7e3a7b0454..79d0be522d 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -35,6 +35,7 @@ from monai.transforms.utility.array import ( AddCoordinateChannels, AddExtremePointsChannel, + ApplyTransformToPoints, AsChannelLast, CastToType, ClassesToIndices, @@ -180,6 +181,9 @@ "ClassesToIndicesd", "ClassesToIndicesD", "ClassesToIndicesDict", + "ApplyTransformToPointsd", + "ApplyTransformToPointsD", + "ApplyTransformToPointsDict", ] DEFAULT_POST_FIX = PostFix.meta() @@ -789,6 +793,7 @@ def __init__( data_shape: Sequence[bool] | bool = True, value_range: Sequence[bool] | bool = True, data_value: Sequence[bool] | bool = False, + meta_info: Sequence[bool] | bool = False, additional_info: Sequence[Callable] | Callable | None = None, name: str = "DataStats", allow_missing_keys: bool = False, @@ -808,6 +813,8 @@ def __init__( data_value: whether to show the raw value of input data. it also can be a sequence of bool, each element corresponds to a key in ``keys``. a typical example is to print some properties of Nifti image: affine, pixdim, etc. + meta_info: whether to show the data of MetaTensor. + it also can be a sequence of bool, each element corresponds to a key in ``keys``. additional_info: user can define callable function to extract additional info from input data. it also can be a sequence of string, each element corresponds to a key in ``keys``. 
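+
+        A minimal sketch of the new ``meta_info`` flag (illustrative values only)::
+
+            stats = DataStatsd(keys="img", meta_info=True)
+            stats({"img": MetaTensor(torch.zeros(1, 8, 8))})  # the log gains a "Meta info: ..." line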
@@ -821,15 +828,34 @@ def __init__( self.data_shape = ensure_tuple_rep(data_shape, len(self.keys)) self.value_range = ensure_tuple_rep(value_range, len(self.keys)) self.data_value = ensure_tuple_rep(data_value, len(self.keys)) + self.meta_info = ensure_tuple_rep(meta_info, len(self.keys)) self.additional_info = ensure_tuple_rep(additional_info, len(self.keys)) self.printer = DataStats(name=name) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator( - d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info + for ( + key, + prefix, + data_type, + data_shape, + value_range, + data_value, + meta_info, + additional_info, + ) in self.key_iterator( + d, + self.prefix, + self.data_type, + self.data_shape, + self.value_range, + self.data_value, + self.meta_info, + self.additional_info, ): - d[key] = self.printer(d[key], prefix, data_type, data_shape, value_range, data_value, additional_info) + d[key] = self.printer( + d[key], prefix, data_type, data_shape, value_range, data_value, meta_info, additional_info + ) return d @@ -1714,6 +1740,10 @@ class RandImageFilterd(MapTransform, RandomizableTransform): Probability the transform is applied to the data allow_missing_keys: Don't raise exception if key is missing. + + Note: + - This transform does not scale output image values automatically to match the range of the input. + The output should be scaled by later transforms to match the input if this is desired. """ backend = ImageFilter.backend @@ -1740,6 +1770,77 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class ApplyTransformToPointsd(MapTransform, InvertibleTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ApplyTransformToPoints`. + The input coordinates are assumed to be in the shape (C, N, 2 or 3), + where C represents the number of channels and N denotes the number of points. + The output has the same shape as the input. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + refer_keys: The key of the reference item used for transformation. + It can directly refer to an affine or an image from which the affine can be derived. It can also be a + sequence of keys, in which case each refers to the affine applied to the matching points in `keys`. + dtype: The desired data type for the output. + affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates + from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary + Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when + applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly. + The matrix is always converted to float64 for computation, which can be computationally + expensive when applied to a large number of points. + If None, will try to use the affine matrix from the refer data. + invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``. + Typically, the affine matrix is derived from the image, while the points are in world coordinates. + If you want to align the points with the image, set this to ``True``. Otherwise, set it to ``False``. + affine_lps_to_ras: Defaults to ``False``. 
Set to `True` if your point data is in the RAS coordinate system + or you're using `ITKReader` with `affine_lps_to_ras=True`. + This ensures the correct application of the affine transformation between LPS (left-posterior-superior) + and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine + matrix are in the same coordinate system. + allow_missing_keys: Don't raise exception if key is missing. + """ + + def __init__( + self, + keys: KeysCollection, + refer_keys: KeysCollection | None = None, + dtype: DtypeLike | torch.dtype = torch.float64, + affine: torch.Tensor | None = None, + invert_affine: bool = True, + affine_lps_to_ras: bool = False, + allow_missing_keys: bool = False, + ): + MapTransform.__init__(self, keys, allow_missing_keys) + self.refer_keys = ensure_tuple_rep(refer_keys, len(self.keys)) + self.converter = ApplyTransformToPoints( + dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + ) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]): + d = dict(data) + for key, refer_key in self.key_iterator(d, self.refer_keys): + coords = d[key] + affine = None # represents using affine given in constructor + if refer_key is not None: + if refer_key in d: + refer_data = d[refer_key] + else: + raise KeyError(f"The refer_key '{refer_key}' is not found in the data.") + + # use the "affine" member of refer_data, or refer_data itself, as the affine matrix + affine = getattr(refer_data, "affine", refer_data) + d[key] = self.converter(coords, affine) + return d + + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter.inverse(d[key]) + return d + + RandImageFilterD = RandImageFilterDict = RandImageFilterd ImageFilterD = ImageFilterDict = ImageFilterd IdentityD = IdentityDict = Identityd @@ -1780,3 +1881,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N RandCuCIMD = RandCuCIMDict = RandCuCIMd AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd +ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 7027c07d67..e7e1616e13 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -27,6 +27,7 @@ import monai from monai.config import DtypeLike, IndexSelection from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.data.utils import to_affine_nd from monai.networks.layers import GaussianFilter from monai.networks.utils import meshgrid_ij from monai.transforms.compose import Compose @@ -35,6 +36,7 @@ from monai.transforms.utils_pytorch_numpy_unification import ( any_np_pt, ascontiguousarray, + concatenate, cumsum, isfinite, nonzero, @@ -580,7 +582,8 @@ def weighted_patch_samples( if not v[-1] or not isfinite(v[-1]) or v[-1] < 0: # uniform sampling idx = r_state.randint(0, len(v), size=n_samples) else: - r, *_ = convert_to_dst_type(r_state.random(n_samples), v) + r_samples = r_state.random(n_samples) + r, *_ = convert_to_dst_type(r_samples, v, dtype=r_samples.dtype) idx = searchsorted(v, r * v[-1], right=True) # type: ignore idx, *_ = convert_to_dst_type(idx, v, dtype=torch.int) # type: ignore # compensate 'valid' mode @@ -1861,7 +1864,7 @@ class Fourier: """ @staticmethod - def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> 
NdarrayOrTensor:
+    def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, as_contiguous: bool = False) -> NdarrayOrTensor:
         """
         Applies fourier transform and shifts the zero-frequency component to the center of the spectrum.
         Only the spatial dimensions get transformed.
 
@@ -1869,6 +1872,7 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
         Args:
             x: Image to transform.
             spatial_dims: Number of spatial dimensions.
+            as_contiguous: Whether to convert the resulting NumPy array or PyTorch tensor to be contiguous.
 
         Returns
             k: K-space data.
@@ -1883,10 +1887,12 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
             k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims)
         else:
             k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims)
-        return k
+        return ascontiguousarray(k) if as_contiguous else k
 
     @staticmethod
-    def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None) -> NdarrayOrTensor:
+    def inv_shift_fourier(
+        k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None, as_contiguous: bool = False
+    ) -> NdarrayOrTensor:
         """
         Applies inverse shift and fourier transform. Only the spatial dimensions are transformed.
 
@@ -1894,6 +1900,7 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None
         Args:
             k: K-space data.
             spatial_dims: Number of spatial dimensions.
+            as_contiguous: Whether to convert the resulting NumPy array or PyTorch tensor to be contiguous.
 
         Returns:
             x: Tensor in image space.
@@ -1908,7 +1915,7 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None
             out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real
         else:
             out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real
-        return out
+        return ascontiguousarray(out) if as_contiguous else out
 
 
 def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Hashable | None = None) -> int:
@@ -2512,6 +2519,7 @@ def distance_transform_edt(
             block_params=block_params,
             float64_distances=float64_distances,
         )
+        torch.cuda.synchronize()
     else:
         if not has_ndimage:
             raise RuntimeError("scipy.ndimage required if cupy is not available")
@@ -2545,7 +2553,7 @@ def distance_transform_edt(
 
     r_vals = []
     if return_distances and distances_original is None:
-        r_vals.append(distances)
+        r_vals.append(distances_ if use_cp else distances)
     if return_indices and indices_original is None:
         r_vals.append(indices)
     if not r_vals:
@@ -2554,5 +2562,26 @@ def distance_transform_edt(
     return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0]
 
 
+def apply_affine_to_points(data: torch.Tensor, affine: torch.Tensor, dtype: DtypeLike | torch.dtype | None = None):
+    """
+    Apply an affine transformation to a set of points.
+
+    Args:
+        data: input data to apply affine transformation, should be a tensor of shape (C, N, 2 or 3),
+            where C represents the number of channels and N denotes the number of points.
+        affine: affine matrix to be applied, should be a tensor of shape (3, 3) or (4, 4).
+        dtype: output data dtype.
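+
+    A minimal sketch (illustrative values only): the identity affine leaves points unchanged::
+
+        pts = torch.zeros((1, 2, 3))  # one channel, two 3D points
+        out = apply_affine_to_points(pts, torch.eye(4))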
+    """
+    data_: torch.Tensor = convert_to_tensor(data, track_meta=False, dtype=torch.float64)
+    affine = to_affine_nd(data_.shape[-1], affine)
+
+    homogeneous: torch.Tensor = concatenate(
+        (data_, torch.ones((data_.shape[0], data_.shape[1], 1), dtype=data_.dtype, device=data_.device)), axis=2
+    )  # type: ignore
+    transformed_homogeneous = torch.matmul(homogeneous, affine.T)
+    transformed_coordinates = transformed_homogeneous[:, :, :-1]
+    out, *_ = convert_to_dst_type(transformed_coordinates, data, dtype=dtype)
+
+    return out
+
+
 if __name__ == "__main__":
     print_transform_backends()
diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py
index 03fa1ceed1..916c1a6c70 100644
--- a/monai/utils/__init__.py
+++ b/monai/utils/__init__.py
@@ -11,8 +11,6 @@
 
 from __future__ import annotations
 
-# have to explicitly bring these in here to resolve circular import issues
-from .aliases import alias, resolve_name
 from .component_store import ComponentStore
 from .decorators import MethodReplacer, RestartGenerator
 from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
@@ -40,6 +38,7 @@
     GridSamplePadMode,
     HoVerNetBranch,
     HoVerNetMode,
+    IgniteInfo,
     InterpolateMode,
     JITMetadataKeys,
     LazyAttr,
@@ -109,7 +108,6 @@
     allow_missing_reference,
     damerau_levenshtein_distance,
     exact_version,
-    export,
     get_full_type_name,
     get_package_version,
     get_torch_version_tuple,
@@ -148,7 +146,8 @@
     dtype_numpy_to_torch,
     dtype_torch_to_numpy,
     get_dtype,
+    get_dtype_string,
     get_equivalent_dtype,
     get_numpy_dtype_from_string,
     get_torch_dtype_from_string,
 )
diff --git a/monai/utils/aliases.py b/monai/utils/aliases.py
deleted file mode 100644
index 2974eec2eb..0000000000
--- a/monai/utils/aliases.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# Copyright (c) MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#     http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This module is written for configurable workflow, not currently in use.
-"""
-
-from __future__ import annotations
-
-import importlib
-import inspect
-import sys
-import threading
-
-alias_lock = threading.RLock()
-GlobalAliases = {}
-
-__all__ = ["alias", "resolve_name"]
-
-
-def alias(*names):
-    """
-    Stores the decorated function or class in the global aliases table under the given names and as the `__aliases__`
-    member of the decorated object. This new member will contain all alias names declared for that object.
-    """
-
-    def _outer(obj):
-        for n in names:
-            with alias_lock:
-                GlobalAliases[n] = obj
-
-        # set the member list __aliases__ to contain the alias names defined by the decorator for `obj`
-        obj.__aliases__ = getattr(obj, "__aliases__", ()) + tuple(names)
-
-        return obj
-
-    return _outer
-
-
-def resolve_name(name):
-    """
-    Search for the declaration (function or class) with the given name. This will first search the list of aliases to
-    see if it was declared with this aliased name, then search treating `name` as a fully qualified name, then search
-    the loaded modules for one having a declaration with the given name.
If no declaration is found, raise ValueError. - - Raises: - ValueError: When the module is not found. - ValueError: When the module does not have the specified member. - ValueError: When multiple modules with the declaration name are found. - ValueError: When no module with the specified member is found. - - """ - # attempt to resolve an alias - with alias_lock: - obj = GlobalAliases.get(name) - - if name in GlobalAliases and obj is None: - raise AssertionError - - # attempt to resolve a qualified name - if obj is None and "." in name: - modname, declname = name.rsplit(".", 1) - - try: - mod = importlib.import_module(modname) - obj = getattr(mod, declname, None) - except ModuleNotFoundError as not_found_err: - raise ValueError(f"Module {modname!r} not found.") from not_found_err - - if obj is None: - raise ValueError(f"Module {modname!r} does not have member {declname!r}.") - - # attempt to resolve a simple name - if obj is None: - # Get all modules having the declaration/import, need to check here that getattr returns something which doesn't - # equate to False since in places __getattr__ returns 0 incorrectly: - # https://github.com/tensorflow/tensorboard/blob/a22566561d2b4fea408755a951ac9eaf3a156f8e/ - # tensorboard/compat/tensorflow_stub/pywrap_tensorflow.py#L35 - mods = [m for m in list(sys.modules.values()) if getattr(m, name, None)] - - if len(mods) > 0: # found modules with this declaration or import - if len(mods) > 1: # found multiple modules, need to determine if ambiguous or just multiple imports - foundmods = set(filter(None, {inspect.getmodule(getattr(m, name)) for m in mods})) # resolve imports - - if len(foundmods) > 1: # found multiple declarations with the same name - modnames = [m.__name__ for m in foundmods] - msg = f"Multiple modules ({modnames!r}) with declaration name {name!r} found, resolution is ambiguous." 
- raise ValueError(msg) - mods = list(foundmods) - - obj = getattr(mods[0], name) - - if obj is None: - raise ValueError(f"No module with member {name!r} found.") - - return obj diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 2418b43591..c7ff988027 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -24,7 +24,7 @@ import torch import torch.distributed as dist -from monai.config import IgniteInfo +from monai.utils.enums import IgniteInfo from monai.utils.module import min_version, optional_import idist, has_ignite = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") diff --git a/monai/utils/enums.py b/monai/utils/enums.py index eba1be18ed..1fbf3ffa05 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -15,8 +15,6 @@ from enum import Enum from typing import TYPE_CHECKING -from monai.config import IgniteInfo -from monai.utils import deprecated from monai.utils.module import min_version, optional_import __all__ = [ @@ -56,13 +54,13 @@ "DataStatsKeys", "ImageStatsKeys", "LabelStatsKeys", - "AlgoEnsembleKeys", "HoVerNetMode", "HoVerNetBranch", "LazyAttr", "BundleProperty", "BundlePropertyConfig", "AlgoKeys", + "IgniteInfo", ] @@ -91,14 +89,6 @@ def __repr__(self): return self.value -if TYPE_CHECKING: - from ignite.engine import EventEnum -else: - EventEnum, _ = optional_import( - "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base" - ) - - class NumpyPadMode(StrEnum): """ See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html @@ -615,17 +605,6 @@ class LabelStatsKeys(StrEnum): LABEL_NCOMP = "ncomponents" -@deprecated(since="1.2", removed="1.4", msg_suffix="please use `AlgoKeys` instead.") -class AlgoEnsembleKeys(StrEnum): - """ - Default keys for Mixed Ensemble - """ - - ID = "identifier" - ALGO = "infer_algo" - SCORE = "best_metric" - - class HoVerNetMode(StrEnum): """ Modes for HoVerNet model: @@ -730,6 +709,35 @@ class AdversarialKeys(StrEnum): DISCRIMINATOR_LOSS = "discriminator_loss" +class OrderingType(StrEnum): + RASTER_SCAN = "raster_scan" + S_CURVE = "s_curve" + RANDOM = "random" + + +class OrderingTransformations(StrEnum): + ROTATE_90 = "rotate_90" + TRANSPOSE = "transpose" + REFLECT = "reflect" + + +class IgniteInfo(StrEnum): + """ + Config information of the PyTorch ignite package. + + """ + + OPT_IMPORT_VERSION = "0.4.11" + + +if TYPE_CHECKING: + from ignite.engine import EventEnum +else: + EventEnum, _ = optional_import( + "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base" + ) + + class AdversarialIterationEvents(EventEnum): """ Keys used to define events as used in the AdversarialTrainer. 
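For context, the relocated `IgniteInfo` is consumed via `optional_import` throughout this patch; a minimal sketch (the `Engine` attribute here is illustrative, `dist.py` above and `jupyter_utils.py` below follow the same pattern):

    from monai.utils.enums import IgniteInfo
    from monai.utils.module import min_version, optional_import

    # resolved lazily; has_ignite is False when ignite is missing or older than the minimum version
    Engine, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")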
@@ -746,15 +754,3 @@ class AdversarialIterationEvents(EventEnum): DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed" DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed" DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed" - - -class OrderingType(StrEnum): - RASTER_SCAN = "raster_scan" - S_CURVE = "s_curve" - RANDOM = "random" - - -class OrderingTransformations(StrEnum): - ROTATE_90 = "rotate_90" - TRANSPOSE = "transpose" - REFLECT = "reflect" diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index 7dcd0e62cd..b1b43a6767 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -24,7 +24,7 @@ import numpy as np import torch -from monai.config import IgniteInfo +from monai.utils import IgniteInfo from monai.utils.module import min_version, optional_import try: diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 40370ca2c6..6386aae713 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -887,7 +887,7 @@ def run_cmd(cmd_list: list[str], **kwargs: Any) -> subprocess.CompletedProcess: if kwargs.pop("run_cmd_verbose", False): import monai - monai.apps.utils.get_logger("run_cmd").info(f"{cmd_list}") + monai.apps.utils.get_logger("run_cmd").info(f"{cmd_list}") # type: ignore[attr-defined] try: return subprocess.run(cmd_list, **kwargs) except subprocess.CalledProcessError as e: diff --git a/monai/utils/module.py b/monai/utils/module.py index 78087aef84..df5fe873ae 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -43,13 +43,11 @@ "InvalidPyTorchVersionError", "OptionalImportError", "exact_version", - "export", "damerau_levenshtein_distance", "look_up_option", "min_version", "optional_import", "require_pkg", - "load_submodules", "instantiate", "get_full_type_name", "get_package_version", @@ -172,28 +170,6 @@ def damerau_levenshtein_distance(s1: str, s2: str) -> int: return d[string_1_length - 1, string_2_length - 1] -def export(modname): - """ - Make the decorated object a member of the named module. This will also add the object under its aliases if it has - a `__aliases__` member, thus this decorator should be before the `alias` decorator to pick up those names. Alias - names which conflict with package names or existing members will be ignored. 
- """ - - def _inner(obj): - mod = import_module(modname) - if not hasattr(mod, obj.__name__): - setattr(mod, obj.__name__, obj) - - # add the aliases for `obj` to the target module - for alias in getattr(obj, "__aliases__", ()): - if not hasattr(mod, alias): - setattr(mod, alias, obj) - - return obj - - return _inner - - def load_submodules( basemod: ModuleType, load_all: bool = True, exclude_pattern: str = "(.*[tT]est.*)|(_.*)" ) -> tuple[list[ModuleType], list[str]]: diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index e4f97fc4a6..420e935b33 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -33,6 +33,7 @@ "get_equivalent_dtype", "convert_data_type", "get_dtype", + "get_dtype_string", "convert_to_cupy", "convert_to_numpy", "convert_to_tensor", @@ -102,6 +103,13 @@ def get_dtype(data: Any) -> DtypeLike | torch.dtype: return type(data) +def get_dtype_string(dtype: DtypeLike | torch.dtype) -> str: + """Get a string representation of the dtype.""" + if isinstance(dtype, torch.dtype): + return str(dtype)[6:] + return str(dtype)[3:] + + def convert_to_tensor( data: Any, dtype: DtypeLike | torch.dtype = None, diff --git a/pyproject.toml b/pyproject.toml index 53ca608d20..c2ab92a43d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,7 @@ requires = [ "setuptools", "torch>=1.9", "ninja", + "packaging" ] [tool.black] diff --git a/requirements-dev.txt b/requirements-dev.txt index 9aad0804e6..6d0ccd378a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -59,3 +59,5 @@ nvidia-ml-py huggingface_hub pyamg>=5.0.0 git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 +onnx_graphsurgeon +polygraphy diff --git a/setup.cfg b/setup.cfg index 1ce4a3f34c..694dc969d9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,6 +40,7 @@ python_requires = >= 3.9 setup_requires = torch ninja + packaging install_requires = torch>=1.9 numpy>=1.24,<2.0 @@ -160,6 +161,9 @@ lpips = lpips==0.1.4 pynvml = nvidia-ml-py +polygraphy = + polygraphy + # # workaround https://github.com/Project-MONAI/MONAI/issues/5882 # MetricsReloaded = # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded diff --git a/tests/min_tests.py b/tests/min_tests.py index f80d06f5d3..f39d3f9843 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -186,6 +186,7 @@ def run_testsuit(): "test_torchvisiond", "test_transchex", "test_transformerblock", + "test_trt_compile", "test_unetr", "test_unetr_block", "test_vit", @@ -211,6 +212,7 @@ def run_testsuit(): "test_ultrasound_confidence_map_transform", "test_vista3d_utils", "test_vista3d_transforms", + "test_matshow3d", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_apply_transform_to_points.py b/tests/test_apply_transform_to_points.py new file mode 100644 index 0000000000..0c16603996 --- /dev/null +++ b/tests/test_apply_transform_to_points.py @@ -0,0 +1,81 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.transforms.utility.array import ApplyTransformToPoints +from monai.utils import set_determinism + +set_determinism(seed=0) + +DATA_2D = torch.rand(1, 64, 64) +DATA_3D = torch.rand(1, 64, 64, 64) +POINT_2D_WORLD = torch.tensor([[[2, 2], [2, 4], [4, 6]]]) +POINT_2D_IMAGE = torch.tensor([[[1, 1], [1, 2], [2, 3]]]) +POINT_2D_IMAGE_RAS = torch.tensor([[[-1, -1], [-1, -2], [-2, -3]]]) +POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]]) +POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]]) +POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]]) +AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) +AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]) + +TEST_CASES = [ + [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE], + [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD], + [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD], + [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, True, POINT_2D_IMAGE_RAS], + [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], + [ + MetaTensor(DATA_3D, affine=AFFINE_2), + MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2), + None, + False, + False, + POINT_3D_WORLD, + ], + [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS], +] + +TEST_CASES_WRONG = [ + [POINT_2D_WORLD, True, None], + [POINT_2D_WORLD.unsqueeze(0), False, None], + [POINT_3D_WORLD[..., 0:1], False, None], + [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])], +] + + +class TestCoordinateTransform(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_transform_coordinates(self, image, points, affine, invert_affine, affine_lps_to_ras, expected_output): + transform = ApplyTransformToPoints( + dtype=torch.int64, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras + ) + affine = image.affine if image is not None else None + output = transform(points, affine) + self.assertTrue(torch.allclose(output, expected_output)) + invert_out = transform.inverse(output) + self.assertTrue(torch.allclose(invert_out, points)) + + @parameterized.expand(TEST_CASES_WRONG) + def test_wrong_input(self, input, invert_affine, affine): + transform = ApplyTransformToPoints(dtype=torch.int64, invert_affine=invert_affine) + with self.assertRaises(ValueError): + transform(input, affine) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_apply_transform_to_pointsd.py b/tests/test_apply_transform_to_pointsd.py new file mode 100644 index 0000000000..978113931c --- /dev/null +++ b/tests/test_apply_transform_to_pointsd.py @@ -0,0 +1,185 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.transforms.utility.dictionary import ApplyTransformToPointsd +from monai.utils import set_determinism + +set_determinism(seed=0) + +DATA_2D = torch.rand(1, 64, 64) +DATA_3D = torch.rand(1, 64, 64, 64) +POINT_2D_WORLD = torch.tensor([[[2, 2], [2, 4], [4, 6]]]) +POINT_2D_IMAGE = torch.tensor([[[1, 1], [1, 2], [2, 3]]]) +POINT_2D_IMAGE_RAS = torch.tensor([[[-1, -1], [-1, -2], [-2, -3]]]) +POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]]) +POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]]) +POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]]) +AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) +AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]) + +TEST_CASES = [ + [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE], # use image affine + [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD], # use point affine + [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD], # use input affine + [None, POINT_2D_WORLD, AFFINE_1, True, False, POINT_2D_IMAGE], # use input affine + [ + MetaTensor(DATA_2D, affine=AFFINE_1), + POINT_2D_WORLD, + None, + True, + True, + POINT_2D_IMAGE_RAS, + ], # test affine_lps_to_ras + [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], + ["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], # use refer_data itself + [ + MetaTensor(DATA_3D, affine=AFFINE_2), + MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2), + None, + False, + False, + POINT_3D_WORLD, + ], + [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS], + [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS], +] +TEST_CASES_SEQUENCE = [ + [ + (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)), + [POINT_2D_WORLD, POINT_3D_WORLD], + None, + True, + False, + ["image_1", "image_2"], + [POINT_2D_IMAGE, POINT_3D_IMAGE], + ], # use image affine + [ + (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)), + [POINT_2D_WORLD, POINT_3D_WORLD], + None, + True, + True, + ["image_1", "image_2"], + [POINT_2D_IMAGE_RAS, POINT_3D_IMAGE_RAS], + ], # test affine_lps_to_ras + [ + (None, None), + [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)], + None, + False, + False, + None, + [POINT_2D_WORLD, POINT_3D_WORLD], + ], # use point affine + [ + (None, None), + [POINT_2D_WORLD, POINT_2D_WORLD], + AFFINE_1, + True, + False, + None, + [POINT_2D_IMAGE, POINT_2D_IMAGE], + ], # use input affine + [ + (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)), + [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)], + None, + 
False, + False, + ["image_1", "image_2"], + [POINT_2D_WORLD, POINT_3D_WORLD], + ], +] + +TEST_CASES_WRONG = [ + [POINT_2D_WORLD, True, None, None], + [POINT_2D_WORLD.unsqueeze(0), False, None, None], + [POINT_3D_WORLD[..., 0:1], False, None, None], + [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]]), None], + [POINT_3D_WORLD, False, None, "image"], + [POINT_3D_WORLD, False, None, []], +] + + +class TestCoordinateTransform(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_transform_coordinates(self, image, points, affine, invert_affine, affine_lps_to_ras, expected_output): + data = { + "image": image, + "point": points, + "affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]), + } + refer_keys = "image" if (image is not None and image != "affine") else image + transform = ApplyTransformToPointsd( + keys="point", + refer_keys=refer_keys, + dtype=torch.int64, + affine=affine, + invert_affine=invert_affine, + affine_lps_to_ras=affine_lps_to_ras, + ) + output = transform(data) + + self.assertTrue(torch.allclose(output["point"], expected_output)) + invert_out = transform.inverse(output) + self.assertTrue(torch.allclose(invert_out["point"], points)) + + @parameterized.expand(TEST_CASES_SEQUENCE) + def test_transform_coordinates_sequences( + self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output + ): + data = {"image_1": image[0], "image_2": image[1], "point_1": points[0], "point_2": points[1]} + keys = ["point_1", "point_2"] + transform = ApplyTransformToPointsd( + keys=keys, + refer_keys=refer_keys, + dtype=torch.int64, + affine=affine, + invert_affine=invert_affine, + affine_lps_to_ras=affine_lps_to_ras, + ) + output = transform(data) + + self.assertTrue(torch.allclose(output["point_1"], expected_output[0])) + self.assertTrue(torch.allclose(output["point_2"], expected_output[1])) + invert_out = transform.inverse(output) + self.assertTrue(torch.allclose(invert_out["point_1"], points[0])) + + @parameterized.expand(TEST_CASES_WRONG) + def test_wrong_input(self, input, invert_affine, affine, refer_keys): + if refer_keys == []: + with self.assertRaises(ValueError): + ApplyTransformToPointsd( + keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys + ) + else: + transform = ApplyTransformToPointsd( + keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys + ) + data = {"point": input} + if refer_keys == "image": + with self.assertRaises(KeyError): + transform(data) + else: + with self.assertRaises(ValueError): + transform(data) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index 331d228f1e..02a9f40846 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -89,7 +89,7 @@ TEST_CASE_10 = [ ["network.json", "test_output.pt", "test_input.pt", "large_files.yaml"], "test_bundle", - "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle_v0.1.2.zip", + "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle_v0.1.3.zip", {"model.pt": "27952767e2e154e3b0ee65defc5aed38", "model.ts": "97746870fe591f69ac09827175b00675"}, ] diff --git a/tests/test_compute_generalized_dice.py b/tests/test_compute_generalized_dice.py index e04444e988..985a01e993 100644 --- a/tests/test_compute_generalized_dice.py +++ 
b/tests/test_compute_generalized_dice.py @@ -22,17 +22,17 @@ _device = "cuda:0" if torch.cuda.is_available() else "cpu" # keep background -TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1) +TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1) with compute_generalized_dice { "y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device), "y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device), "include_background": True, }, - [0.8], + [[0.8]], ] # remove background -TEST_CASE_2 = [ # y (2, 1, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background) +TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background) with GeneralizedDiceScore { "y_pred": torch.tensor( [ @@ -47,32 +47,32 @@ ] ), "include_background": False, + "reduction": "mean_batch", }, - [0.1667, 0.6667], + [0.583333, 0.333333], ] -# should return 0 for both cases -TEST_CASE_3 = [ +TEST_CASE_3 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (1) with GeneralizedDiceScore { "y_pred": torch.tensor( [ - [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]], - [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]], + [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], + [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], ] ), "y": torch.tensor( [ [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], - [[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], ] ), "include_background": True, + "reduction": "mean", }, - [0.0, 0.0], + [0.5454], ] -TEST_CASE_4 = [ - {"include_background": True, "reduction": "mean_batch"}, +TEST_CASE_4 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (1) with GeneralizedDiceScore { "y_pred": torch.tensor( [ @@ -83,15 +83,36 @@ "y": torch.tensor( [ [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], - [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], + [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], ] ), + "include_background": True, + "reduction": "sum", }, - [0.5455], + [1.045455], +] + +TEST_CASE_5 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice + {"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, + [[1.0000, 1.0000], [1.0000, 1.0000]], ] -TEST_CASE_5 = [ - {"include_background": True, "reduction": "sum_batch"}, +TEST_CASE_6 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice + {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, + [[0.0000, 0.0000], [0.0000, 0.0000]], +] + +TEST_CASE_7 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice + {"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, + [[0.0000, 0.0000], [0.0000, 0.0000]], +] + +TEST_CASE_8 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice + {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, + [[1.0000, 1.0000], [1.0000, 1.0000]], +] + +TEST_CASE_9 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2) with GeneralizedDiceScore { "y_pred": torch.tensor( [ @@ -102,61 +123,118 @@ "y": torch.tensor( [ [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], - [[[1.0, 1.0], 
[1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], ] ), + "include_background": True, + "reduction": "mean_channel", }, - 1.0455, + [0.545455, 0.545455], ] -TEST_CASE_6 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [1.0000, 1.0000]] -TEST_CASE_7 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [0.0000, 0.0000]] - -TEST_CASE_8 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [0.0000, 0.0000]] +TEST_CASE_10 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 3) with compute_generalized_dice + # and (3) with GeneralizedDiceScore "mean_batch" + { + "y_pred": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], + [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], + ] + ), + "y": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], + ] + ), + "include_background": True, + }, + [[0.857143, 0.0, 0.0], [0.5, 0.4, 0.666667]], +] -TEST_CASE_9 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [1.0000, 1.0000]] +TEST_CASE_11 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 1) with compute_generalized_dice (summed over classes) + # and (2) with GeneralizedDiceScore "mean_channel" + { + "y_pred": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], + [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], + ] + ), + "y": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], + ] + ), + "include_background": True, + "sum_over_classes": True, + }, + [[0.545455], [0.545455]], +] class TestComputeGeneralizedDiceScore(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) def test_device(self, input_data, _expected_value): + """ + Test if the result tensor is on the same device as the input tensor. + """ result = compute_generalized_dice(**input_data) np.testing.assert_equal(result.device, input_data["y_pred"].device) - # Functional part tests - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]) def test_value(self, input_data, expected_value): + """ + Test if the computed generalized dice score matches the expected value. + """ result = compute_generalized_dice(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) - # Functional part tests - @parameterized.expand([TEST_CASE_3]) - def test_nans(self, input_data, expected_value): - result = compute_generalized_dice(**input_data) - self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value)) - - # Samplewise tests - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_9]) def test_value_class(self, input_data, expected_value): - # same test as for compute_meandice - vals = {} - vals["y_pred"] = input_data.pop("y_pred") - vals["y"] = input_data.pop("y") + """ + Test if the GeneralizedDiceScore class computes the correct values. 
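+
+        The pattern under test, restated as a sketch (illustrative)::
+
+            metric = GeneralizedDiceScore(include_background=True, reduction="mean")
+            metric(y_pred=y_pred, y=y)  # accumulates the per-batch scores
+            score = metric.aggregate()  # applies the configured reduction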
+ """ + y_pred = input_data.pop("y_pred") + y = input_data.pop("y") generalized_dice_score = GeneralizedDiceScore(**input_data) - generalized_dice_score(**vals) - result = generalized_dice_score.aggregate(reduction="none") + generalized_dice_score(y_pred=y_pred, y=y) + result = generalized_dice_score.aggregate() np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) - # Aggregation tests - @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) - def test_nans_class(self, params, input_data, expected_value): - generalized_dice_score = GeneralizedDiceScore(**params) - generalized_dice_score(**input_data) - result = generalized_dice_score.aggregate() + @parameterized.expand([TEST_CASE_10]) + def test_values_compare(self, input_data, expected_value): + """ + Compare the results of compute_generalized_dice function and GeneralizedDiceScore class. + """ + result = compute_generalized_dice(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + y_pred = input_data.pop("y_pred") + y = input_data.pop("y") + generalized_dice_score = GeneralizedDiceScore(**input_data, reduction="mean_batch") + generalized_dice_score(y_pred=y_pred, y=y) + result_class_mean = generalized_dice_score.aggregate() + np.testing.assert_allclose(result_class_mean.cpu().numpy(), np.mean(expected_value, axis=0), atol=1e-4) + + @parameterized.expand([TEST_CASE_11]) + def test_values_compare_sum_over_classes(self, input_data, expected_value): + """ + Compare the results when summing over classes between compute_generalized_dice function and GeneralizedDiceScore class. + """ + result = compute_generalized_dice(**input_data) + np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + + y_pred = input_data.pop("y_pred") + y = input_data.pop("y") + input_data.pop("sum_over_classes") + generalized_dice_score = GeneralizedDiceScore(**input_data, reduction="mean_channel") + generalized_dice_score(y_pred=y_pred, y=y) + result_class_mean = generalized_dice_score.aggregate() + np.testing.assert_allclose(result_class_mean.cpu().numpy(), np.mean(expected_value, axis=1), atol=1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index cf1edc8f08..2b00c9f9d1 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -125,6 +125,22 @@ def __call__(self, a, b): [0, 4], ] +TEST_CASE_MERGE_JSON = ["""{"key1": [0], "key2": [0] }""", """{"key1": [1], "+key2": [4] }""", "json", [1], [0, 4]] + +TEST_CASE_MERGE_YAML = [ + """ + key1: 0 + key2: [0] + """, + """ + key1: 1 + +key2: [4] + """, + "yaml", + 1, + [0, 4], +] + class TestConfigParser(unittest.TestCase): @@ -357,6 +373,22 @@ def test_parse_json_warn(self, config_string, extension, expected_unique_val, ex self.assertEqual(parser.get_parsed_content("key#unique"), expected_unique_val) self.assertIn(parser.get_parsed_content("key#duplicate"), expected_duplicate_vals) + @parameterized.expand([TEST_CASE_MERGE_JSON, TEST_CASE_MERGE_YAML]) + @skipUnless(has_yaml, "Requires pyyaml") + def test_load_configs( + self, config_string, config_string2, extension, expected_overridden_val, expected_merged_vals + ): + with tempfile.TemporaryDirectory() as tempdir: + config_path1 = Path(tempdir) / f"config1.{extension}" + config_path2 = Path(tempdir) / f"config2.{extension}" + config_path1.write_text(config_string) + config_path2.write_text(config_string2) + + parser = ConfigParser.load_config_files([config_path1, config_path2]) + + 
self.assertEqual(parser["key1"], expected_overridden_val) + self.assertEqual(parser["key2"], expected_merged_vals) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_convert_box_points.py b/tests/test_convert_box_points.py new file mode 100644 index 0000000000..5e3d7ee645 --- /dev/null +++ b/tests/test_convert_box_points.py @@ -0,0 +1,121 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.data.box_utils import convert_box_to_standard_mode +from monai.transforms.spatial.array import ConvertBoxToPoints, ConvertPointsToBoxes +from tests.utils import assert_allclose + +TEST_CASE_POINTS_2D = [ + [ + torch.tensor([[10, 20, 30, 40], [50, 60, 70, 80]]), + "xyxy", + torch.tensor([[[10, 20], [30, 20], [30, 40], [10, 40]], [[50, 60], [70, 60], [70, 80], [50, 80]]]), + ], + [torch.tensor([[10, 20, 20, 20]]), "ccwh", torch.tensor([[[0, 10], [20, 10], [20, 30], [0, 30]]])], +] +TEST_CASE_POINTS_3D = [ + [ + torch.tensor([[10, 20, 30, 40, 50, 60], [70, 80, 90, 100, 110, 120]]), + "xyzxyz", + torch.tensor( + [ + [ + [10, 20, 30], + [40, 20, 30], + [40, 50, 30], + [10, 50, 30], + [10, 20, 60], + [40, 20, 60], + [40, 50, 60], + [10, 50, 60], + ], + [ + [70, 80, 90], + [100, 80, 90], + [100, 110, 90], + [70, 110, 90], + [70, 80, 120], + [100, 80, 120], + [100, 110, 120], + [70, 110, 120], + ], + ] + ), + ], + [ + torch.tensor([[10, 20, 30, 10, 10, 10]]), + "cccwhd", + torch.tensor( + [ + [ + [5, 15, 25], + [15, 15, 25], + [15, 25, 25], + [5, 25, 25], + [5, 15, 35], + [15, 15, 35], + [15, 25, 35], + [5, 25, 35], + ] + ] + ), + ], + [ + torch.tensor([[10, 20, 30, 40, 50, 60]]), + "xxyyzz", + torch.tensor( + [ + [ + [10, 30, 50], + [20, 30, 50], + [20, 40, 50], + [10, 40, 50], + [10, 30, 60], + [20, 30, 60], + [20, 40, 60], + [10, 40, 60], + ] + ] + ), + ], +] + +TEST_CASES = TEST_CASE_POINTS_2D + TEST_CASE_POINTS_3D + + +class TestConvertBoxToPoints(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_convert_box_to_points(self, boxes, mode, expected_points): + transform = ConvertBoxToPoints(mode=mode) + converted_points = transform(boxes) + assert_allclose(converted_points, expected_points, type_test=False) + + +class TestConvertPointsToBoxes(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_convert_box_to_points(self, boxes, mode, points): + transform = ConvertPointsToBoxes() + converted_boxes = transform(points) + expected_boxes = convert_box_to_standard_mode(boxes, mode) + assert_allclose(converted_boxes, expected_boxes, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index 05453b0694..f9b424f8e1 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -23,6 +23,7 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import DataStats 
TEST_CASE_1 = [ @@ -130,20 +131,55 @@ ] TEST_CASE_8 = [ + { + "prefix": "test data", + "data_type": True, + "data_shape": True, + "value_range": True, + "data_value": True, + "additional_info": np.mean, + "name": "DataStats", + }, np.array([[0, 1], [1, 2]]), "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" "Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", ] +TEST_CASE_9 = [ + np.array([[0, 1], [1, 2]]), + "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" + "Value: [[0 1]\n [1 2]]\n" + "Meta info: '(input is not a MetaTensor)'\n" + "Additional info: 1.0\n", +] + +TEST_CASE_10 = [ + MetaTensor( + torch.tensor([[0, 1], [1, 2]]), + affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64), + meta={"some": "info"}, + ), + "test data statistics:\nType: torch.int64\n" + "Shape: torch.Size([2, 2])\nValue range: (0, 2)\n" + "Value: tensor([[0, 1],\n [1, 2]])\n" + "Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\n" + " [0., 2., 0., 0.],\n" + " [0., 0., 2., 0.],\n" + " [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\n" + "Additional info: 1.0\n", +] + class TestDataStats(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand( + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] + ) def test_value(self, input_param, input_data, expected_print): transform = DataStats(**input_param) _ = transform(input_data) - @parameterized.expand([TEST_CASE_8]) + @parameterized.expand([TEST_CASE_9, TEST_CASE_10]) def test_file(self, input_data, expected_print): with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_data_stats.log") @@ -158,6 +194,7 @@ def test_file(self, input_data, expected_print): "data_shape": True, "value_range": True, "data_value": True, + "meta_info": True, "additional_info": np.mean, "name": name, } diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index ef88300c10..a28a938c40 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -21,6 +21,7 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.transforms import DataStatsd TEST_CASE_1 = [ @@ -150,22 +151,70 @@ ] TEST_CASE_9 = [ + { + "keys": "img", + "prefix": "test data", + "data_shape": True, + "value_range": True, + "data_value": True, + "meta_info": False, + "additional_info": np.mean, + "name": "DataStats", + }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" "Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", ] +TEST_CASE_10 = [ + {"img": np.array([[0, 1], [1, 2]])}, + "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n" + "Value: [[0 1]\n [1 2]]\n" + "Meta info: '(input is not a MetaTensor)'\n" + "Additional info: 1.0\n", +] + +TEST_CASE_11 = [ + { + "img": ( + MetaTensor( + torch.tensor([[0, 1], [1, 2]]), + affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64), + meta={"some": "info"}, + ) + ) + }, + "test data statistics:\nType: torch.int64\n" + "Shape: torch.Size([2, 2])\nValue range: (0, 2)\n" + "Value: tensor([[0, 1],\n [1, 2]])\n" + "Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\n" + " [0., 2., 0., 0.],\n" + " [0., 0., 2., 0.],\n" + " [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\n" + 
"Additional info: 1.0\n", +] + class TestDataStatsd(unittest.TestCase): @parameterized.expand( - [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8] + [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + ] ) def test_value(self, input_param, input_data, expected_print): transform = DataStatsd(**input_param) _ = transform(input_data) - @parameterized.expand([TEST_CASE_9]) + @parameterized.expand([TEST_CASE_10, TEST_CASE_11]) def test_file(self, input_data, expected_print): with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_stats.log") @@ -180,6 +229,7 @@ def test_file(self, input_data, expected_print): "data_shape": True, "value_range": True, "data_value": True, + "meta_info": True, "additional_info": np.mean, "name": name, } diff --git a/tests/test_fastmri_reader.py b/tests/test_fastmri_reader.py index af2eed7db5..06c3954eae 100644 --- a/tests/test_fastmri_reader.py +++ b/tests/test_fastmri_reader.py @@ -17,7 +17,7 @@ from parameterized import parameterized from monai.apps.reconstruction.fastmri_reader import FastMRIReader -from tests.utils import assert_allclose +from tests.utils import SkipIfNoModule, assert_allclose TEST_CASE1 = [ { @@ -64,6 +64,7 @@ ] +@SkipIfNoModule("h5py") class TestMRIUtils(unittest.TestCase): @parameterized.expand([TEST_CASE1, TEST_CASE2]) diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py index f0a419dcf5..5d2e2aa013 100644 --- a/tests/test_gdsdataset.py +++ b/tests/test_gdsdataset.py @@ -23,7 +23,7 @@ from monai.data import GDSDataset, json_hashing from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform from monai.utils import optional_import -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_no_cuda _, has_cp = optional_import("cupy") nib, has_nib = optional_import("nibabel") @@ -70,9 +70,9 @@ def __call__(self, data): return data +@skip_if_no_cuda @unittest.skipUnless(has_cp, "Requires CuPy library.") -@unittest.skipUnless(has_nib, "Requires nibabel package.") -@unittest.skipUnless(has_kvikio_numpy, "Requires scikit-image library.") +@unittest.skipUnless(has_cp and has_kvikio_numpy, "Requires CuPy and kvikio library.") class TestDataset(unittest.TestCase): def test_cache(self): @@ -131,6 +131,7 @@ def test_dtype(self): self.assertEqual(ds[0].dtype, DTYPES[_dtype]) self.assertEqual(ds1[0].dtype, DTYPES[_dtype]) + @unittest.skipUnless(has_nib, "Requires nibabel package.") @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, transform, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py index 317eba1b11..4254a73a6b 100644 --- a/tests/test_handler_garbage_collector.py +++ b/tests/test_handler_garbage_collector.py @@ -19,10 +19,9 @@ from ignite.engine import Engine from parameterized import parameterized -from monai.config import IgniteInfo from monai.data import Dataset from monai.handlers import GarbageCollector -from monai.utils import min_version, optional_import +from monai.utils import IgniteInfo, min_version, optional_import Events, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") diff --git a/tests/test_nrrd_reader.py 
b/tests/test_nrrd_reader.py index 649b9fa94d..5bf958e970 100644 --- a/tests/test_nrrd_reader.py +++ b/tests/test_nrrd_reader.py @@ -40,8 +40,8 @@ "dimension": 4, "space": "left-posterior-superior", "sizes": [3, 4, 4, 1], - "space directions": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], - "space origin": [0.0, 0.0, 0.0], + "space directions": [[0.7, 0.0, 0.0], [0.0, 0.0, -0.8], [0.0, 0.9, 0.0]], + "space origin": [1.0, 5.0, 20.0], }, ] @@ -110,6 +110,10 @@ def test_read_with_header(self, data_shape, filename, expected_shape, dtype, ref np.testing.assert_allclose(image_array, test_image) self.assertIsInstance(image_header, dict) self.assertTupleEqual(tuple(image_header["spatial_shape"]), expected_shape) + np.testing.assert_allclose( + image_header["affine"], + np.array([[-0.7, 0.0, 0.0, -1.0], [0.0, 0.0, -0.9, -5.0], [0.0, -0.8, 0.0, 20.0], [0.0, 0.0, 0.0, 1.0]]), + ) @parameterized.expand([TEST_CASE_8]) def test_read_with_header_index_order_c(self, data_shape, filename, expected_shape, dtype, reference_header): diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py index 47a8f3bfa2..f509065a56 100644 --- a/tests/test_rand_weighted_crop.py +++ b/tests/test_rand_weighted_crop.py @@ -90,6 +90,21 @@ def get_data(ndim): [[63, 37], [31, 43], [66, 20]], ] ) + im = SEG1_2D + weight_map = np.zeros_like(im, dtype=np.int32) + weight_map[0, 30, 20] = 3 + weight_map[0, 45, 44] = 1 + weight_map[0, 60, 50] = 2 + TESTS.append( + [ + "int w 2d", + dict(spatial_size=(10, 12), num_samples=3), + p(im), + q(weight_map), + (1, 10, 12), + [[60, 50], [30, 20], [45, 44]], + ] + ) im = SEG1_3D weight = np.zeros_like(im) weight[0, 5, 30, 17] = 1.1 @@ -149,6 +164,21 @@ def get_data(ndim): [[32, 24, 40], [32, 24, 40], [32, 24, 40]], ] ) + im = SEG1_3D + weight_map = np.zeros_like(im, dtype=np.int32) + weight_map[0, 6, 22, 19] = 4 + weight_map[0, 8, 40, 31] = 2 + weight_map[0, 13, 20, 24] = 3 + TESTS.append( + [ + "int w 3d", + dict(spatial_size=(8, 10, 12), num_samples=3), + p(im), + q(weight_map), + (1, 8, 10, 12), + [[13, 20, 24], [6, 22, 19], [8, 40, 31]], + ] + ) class TestRandWeightedCrop(CropTest): diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py index 7c3a684a00..a7390efe72 100644 --- a/tests/test_scale_intensity_range_percentiles.py +++ b/tests/test_scale_intensity_range_percentiles.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import torch from monai.transforms.intensity.array import ScaleIntensityRangePercentiles from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose @@ -34,6 +35,7 @@ def test_scaling(self): scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8) for p in TEST_NDARRAYS: result = scaler(p(img)) + self.assertEqual(result.dtype, torch.uint8) assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4) def test_relative_scaling(self): diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py index 903f9bd2ca..fb8f5dda72 100644 --- a/tests/test_sure_loss.py +++ b/tests/test_sure_loss.py @@ -65,7 +65,7 @@ def operator(x): loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False) loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True) - self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6) + self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=5) if __name__ == "__main__": diff --git 
a/tests/test_trt_compile.py b/tests/test_trt_compile.py new file mode 100644 index 0000000000..2f9db8f0c2 --- /dev/null +++ b/tests/test_trt_compile.py @@ -0,0 +1,140 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import tempfile +import unittest + +import torch +from parameterized import parameterized + +from monai.handlers import TrtHandler +from monai.networks import trt_compile +from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 +from monai.utils import min_version, optional_import +from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows + +trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) +polygraphy, polygraphy_imported = optional_import("polygraphy") +build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") + +TEST_CASE_1 = ["fp32"] +TEST_CASE_2 = ["fp16"] + + +@skip_if_windows +@skip_if_no_cuda +@skip_if_quick +@unittest.skipUnless(trt_imported, "tensorrt is required") +@unittest.skipUnless(polygraphy_imported, "polygraphy is required") +class TestTRTCompile(unittest.TestCase): + + def setUp(self): + self.gpu_device = torch.cuda.current_device() + + def tearDown(self): + current_device = torch.cuda.current_device() + if current_device != self.gpu_device: + torch.cuda.set_device(self.gpu_device) + + def test_handler(self): + from ignite.engine import Engine + + net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) + data1 = net1.state_dict() + data1["0.weight"] = torch.tensor([0.1]) + data1["1.weight"] = torch.tensor([0.2]) + net1.load_state_dict(data1) + net1.cuda() + + with tempfile.TemporaryDirectory() as tempdir: + engine = Engine(lambda e, b: None) + args = {"method": "torch_trt"} + TrtHandler(net1, tempdir + "/trt_handler", args=args).attach(engine) + engine.run([0] * 8, max_epochs=1) + self.assertIsNotNone(net1._trt_compiler) + self.assertIsNone(net1._trt_compiler.engine) + net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda")) + self.assertIsNotNone(net1._trt_compiler.engine) + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_unet_value(self, precision): + model = UNet( + spatial_dims=3, + in_channels=1, + out_channels=2, + channels=(2, 2, 4, 8, 4), + strides=(2, 2, 2, 2), + num_res_units=2, + norm="batch", + ).cuda() + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + model.eval() + input_example = torch.randn(2, 1, 96, 96, 96).cuda() + output_example = model(input_example) + args: dict = {"builder_optimization_level": 1} + trt_compile( + model, + f"{tmpdir}/test_unet_trt_compile", + args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]}, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + @parameterized.expand([TEST_CASE_1, 
TEST_CASE_2])
+    @unittest.skipUnless(has_sam, "Requires SAM installation")
+    def test_cell_sam_wrapper_value(self, precision):
+        model = cell_sam_wrapper.CellSamWrapper(checkpoint=None).to("cuda")
+        with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
+            model.eval()
+            input_example = torch.randn(1, 3, 128, 128).to("cuda")
+            output_example = model(input_example)
+            trt_compile(
+                model,
+                f"{tmpdir}/test_cell_sam_wrapper_trt_compile",
+                args={"precision": precision, "dynamic_batchsize": [1, 1, 1]},
+            )
+            self.assertIsNone(model._trt_compiler.engine)
+            trt_output = model(input_example)
+            # Check that lazy TRT build succeeded
+            self.assertIsNotNone(model._trt_compiler.engine)
+            torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+    def test_vista3d(self, precision):
+        model = vista3d132(in_channels=1).to("cuda")
+        with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
+            model.eval()
+            input_example = torch.randn(1, 1, 64, 64, 64).to("cuda")
+            output_example = model(input_example)
+            model = trt_compile(
+                model,
+                f"{tmpdir}/test_vista3d_trt_compile",
+                args={"precision": precision, "dynamic_batchsize": [1, 1, 1]},
+                submodule=["image_encoder.encoder", "class_head"],
+            )
+            self.assertIsNotNone(model.image_encoder.encoder._trt_compiler)
+            self.assertIsNotNone(model.class_head._trt_compiler)
+            trt_output = model.forward(input_example)
+            # Check that lazy TRT build succeeded
+            # TODO: set up input_example in such a way that image_encoder.encoder and class_head are called
+            # and uncomment the asserts below
+            # self.assertIsNotNone(model.image_encoder.encoder._trt_compiler.engine)
+            # self.assertIsNotNone(model.class_head._trt_compiler.engine)
+            torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
+
+
+if __name__ == "__main__":
+    unittest.main()

From 2fd93eb84a13934e3aff6ae1ca69e7c79c34d44f Mon Sep 17 00:00:00 2001
From: staydelight
Date: Fri, 11 Oct 2024 16:00:09 +0800
Subject: [PATCH 51/52] fix-issue-6366

Signed-off-by: staydelight
---
 monai/transforms/post/array.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py
index 2e733c4f6c..e9e654eb76 100644
--- a/monai/transforms/post/array.py
+++ b/monai/transforms/post/array.py
@@ -679,18 +679,22 @@ def __init__(self, weights: Sequence[float] | NdarrayOrTensor | None = None) ->
         self.weights = torch.as_tensor(weights, dtype=torch.float) if weights is not None else None

     def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayOrTensor:
-        img_ = self.get_stacked_torch(img)
-        if self.weights is not None:
-            self.weights = self.weights.to(img_.device)
-            shape = tuple(self.weights.shape)
-            for _ in range(img_.ndimension() - self.weights.ndimension()):
-                shape += (1,)
-            weights = self.weights.reshape(*shape)
-
-            img_ = img_ * weights / weights.mean(dim=0, keepdim=True)
-
-        out_pt = torch.mean(img_, dim=0)
-        return self.post_convert(out_pt, img)
+        out_pt = None
+        total_weight = 0.0
+
+        for i, pred in enumerate(img):
+            pred = torch.as_tensor(pred)
+            if out_pt is None:
+                out_pt = torch.zeros_like(pred)
+
+            # fall back to a weight of 1.0 (a plain mean) when no weights are given
+            weight = self.weights[i].to(pred.device) if self.weights is not None else 1.0
+
+            out_pt += pred * weight
+            total_weight += weight
+
+        out_pt /= total_weight
+        return self.post_convert(out_pt, img)

 class VoteEnsemble(Ensemble, Transform):

From 020720d26d30ae3fdfe53bcef4f66df70d16566f Mon Sep 17 00:00:00 2001
From: 
"pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 11 Oct 2024 08:10:44 +0000 Subject: [PATCH 52/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/post/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index e9e654eb76..b860c62e9e 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -681,7 +681,7 @@ def __init__(self, weights: Sequence[float] | NdarrayOrTensor | None = None) -> def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayOrTensor: out_pt = None total_weight = 0.0 - + for i, pred in enumerate(img): pred = torch.as_tensor(pred) if out_pt is None: @@ -689,7 +689,7 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO if self.weights is not None: weight = self.weights[i].to(pred.device) - + out_pt += pred * weight total_weight += weight