diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 020d99af16..98b75cff76 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -480,7 +480,7 @@ def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.max(x, int(dim), **kwargs) # type: ignore - return ret + return ret[0] if isinstance(ret, tuple) else ret def mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor: @@ -546,7 +546,7 @@ def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.min(x, int(dim), **kwargs) # type: ignore - return ret + return ret[0] if isinstance(ret, tuple) else ret def std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor: diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index 6e655289e4..90c0401e46 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -17,7 +17,7 @@ import torch from parameterized import parameterized -from monai.transforms.utils_pytorch_numpy_unification import mode, percentile +from monai.transforms.utils_pytorch_numpy_unification import max, min, mode, percentile from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_quick @@ -27,6 +27,13 @@ TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False]) TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True]) +TEST_MIN_MAX = [] +for p in TEST_NDARRAYS: + TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, min, p(1)]) + TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, min, p([3.1, 3])]) + TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, max, p(5)]) + TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, max, p([5.1, 5])]) + class TestPytorchNumpyUnification(unittest.TestCase): @@ -74,6 +81,11 @@ def test_mode(self, array, expected, to_long): res = mode(array, to_long=to_long) assert_allclose(res, expected) + @parameterized.expand(TEST_MIN_MAX) + def test_min_max(self, array, input_params, func, expected): + res = func(array, **input_params) + assert_allclose(res, expected, type_test=False) + if __name__ == "__main__": unittest.main()