Skip to content

Commit

Permalink
Fix AttributeError when using torch.min and max (#8041)
Browse files Browse the repository at this point in the history
Fixes #8040.

### Description

Only return the `values` component when `torch.min` or `torch.max` returns a namedtuple (i.e. when a `dim` argument is supplied), avoiding the AttributeError.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
  • Loading branch information
KumoLiu authored Aug 23, 2024
1 parent cea80a6 commit de2a819
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
4 changes: 2 additions & 2 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe
else:
ret = torch.max(x, int(dim), **kwargs) # type: ignore

return ret
return ret[0] if isinstance(ret, tuple) else ret


def mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:
Expand Down Expand Up @@ -546,7 +546,7 @@ def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe
else:
ret = torch.min(x, int(dim), **kwargs) # type: ignore

return ret
return ret[0] if isinstance(ret, tuple) else ret


def std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor:
Expand Down
14 changes: 13 additions & 1 deletion tests/test_utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
from parameterized import parameterized

from monai.transforms.utils_pytorch_numpy_unification import mode, percentile
from monai.transforms.utils_pytorch_numpy_unification import max, min, mode, percentile
from monai.utils import set_determinism
from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_quick

Expand All @@ -27,6 +27,13 @@
TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False])
TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True])

# Parameterized cases for the unified min/max helpers.
# Each case is [input array, kwargs for the call, function under test, expected result].
TEST_MIN_MAX = []
for p in TEST_NDARRAYS:
    TEST_MIN_MAX.extend(
        [
            [p(np.array([1, 2, 3, 4, 4, 5])), {}, min, p(1)],
            [p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, min, p([3.1, 3])],
            [p(np.array([1, 2, 3, 4, 4, 5])), {}, max, p(5)],
            [p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, max, p([5.1, 5])],
        ]
    )


class TestPytorchNumpyUnification(unittest.TestCase):

Expand Down Expand Up @@ -74,6 +81,11 @@ def test_mode(self, array, expected, to_long):
res = mode(array, to_long=to_long)
assert_allclose(res, expected)

@parameterized.expand(TEST_MIN_MAX)
def test_min_max(self, array, input_params, func, expected):
res = func(array, **input_params)
assert_allclose(res, expected, type_test=False)


if __name__ == "__main__":
unittest.main()

0 comments on commit de2a819

Please sign in to comment.