From de2a819e82e9c0575a959170d8e534fefe002d08 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 23 Aug 2024 11:16:29 +0800 Subject: [PATCH] Fix AttributeError when using torch.min and max (#8041) Fixes #8040. ### Description Only return values if got a namedtuple when using torch.min and max ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .../transforms/utils_pytorch_numpy_unification.py | 4 ++-- tests/test_utils_pytorch_numpy_unification.py | 14 +++++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 020d99af16..98b75cff76 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -480,7 +480,7 @@ def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.max(x, int(dim), **kwargs) # type: ignore - return ret + return ret[0] if isinstance(ret, tuple) else ret def mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor: @@ -546,7 +546,7 @@ def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.min(x, int(dim), **kwargs) # type: ignore - return ret + return ret[0] if isinstance(ret, tuple) else ret def std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor: diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index 6e655289e4..90c0401e46 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -17,7 +17,7 @@ import torch from parameterized import parameterized -from monai.transforms.utils_pytorch_numpy_unification import mode, percentile +from monai.transforms.utils_pytorch_numpy_unification import max, min, mode, percentile from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_quick @@ -27,6 +27,13 @@ TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False]) TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True]) +TEST_MIN_MAX = [] +for p in TEST_NDARRAYS: + TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, min, p(1)]) + TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, min, p([3.1, 3])]) + TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, max, p(5)]) + TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, max, p([5.1, 5])]) + class TestPytorchNumpyUnification(unittest.TestCase): @@ -74,6 +81,11 @@ def test_mode(self, array, expected, to_long): res = mode(array, to_long=to_long) assert_allclose(res, expected) + @parameterized.expand(TEST_MIN_MAX) + def test_min_max(self, array, input_params, func, expected): + res = func(array, **input_params) + assert_allclose(res, expected, type_test=False) + if __name__ == "__main__": unittest.main()