diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2322f2123f..5dfbcb0e91 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1562,17 +1562,22 @@ def __init__(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | self.filter_size = filter_size self.additional_args_for_filter = kwargs - def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> NdarrayOrTensor: + def __call__( + self, img: NdarrayOrTensor, meta_dict: dict | None = None, applied_operations: list | None = None + ) -> NdarrayOrTensor: """ Args: img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]] meta_dict: An optional dictionary with metadata + applied_operations: An optional list of operations that have been applied to the data Returns: A MetaTensor with the same shape as `img` and identical metadata """ if isinstance(img, MetaTensor): meta_dict = img.meta + applied_operations = img.applied_operations + img_, prev_type, device = convert_data_type(img, torch.Tensor) ndim = img_.ndim - 1 # assumes channel first format @@ -1582,8 +1587,8 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> Ndarr self.filter = ApplyFilter(self.filter) img_ = self._apply_filter(img_) - if meta_dict: - img_ = MetaTensor(img_, meta=meta_dict) + if meta_dict is not None or applied_operations is not None: + img_ = MetaTensor(img_, meta=meta_dict, applied_operations=applied_operations) else: img_, *_ = convert_data_type(img_, prev_type, device) return img_ diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index 841a5d5cd5..985ea95e79 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -17,6 +17,7 @@ import torch from parameterized import parameterized +from monai.data.meta_tensor import MetaTensor from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd @@ -115,6 +116,21 @@ def test_call_3d(self, filter_name): out_tensor = filter(SAMPLE_IMAGE_3D) self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:]) + def test_pass_applied_operations(self): + "Test that applied operations are passed through" + applied_operations = ["op1", "op2"] + image = MetaTensor(SAMPLE_IMAGE_2D, applied_operations=applied_operations) + filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(image) + self.assertEqual(out_tensor.applied_operations, applied_operations) + + def test_pass_empty_metadata_dict(self): + "Test that applied operations are passed through" + image = MetaTensor(SAMPLE_IMAGE_2D, meta={}) + filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS) + out_tensor = filter(image) + self.assertTrue(isinstance(out_tensor, MetaTensor)) + class TestImageFilterDict(unittest.TestCase): @parameterized.expand(SUPPORTED_FILTERS)