diff --git a/lightly/models/utils.py b/lightly/models/utils.py
index 1b0b3da5a..103ed4f81 100644
--- a/lightly/models/utils.py
+++ b/lightly/models/utils.py
@@ -28,23 +28,23 @@ def pool_masked(
-    source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
+    source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean"
 ) -> Tensor:
-    """Reduce image feature maps (B, C, H, W) or (C, H, W) according to an integer
-    index given by `mask` (B, H, W) or (H, W).
+    """Reduce image feature maps :math:`(B, C, H, W)` or :math:`(C, H, W)` according to
+    an integer index given by `mask` :math:`(B, H, W)` or :math:`(H, W)`.

     Args:
-        source: Float tensor of shape (B, C, H, W) or (C, H, W) to be reduced.
-        mask: Integer tensor of shape (B, H, W) or (H, W) containing the integer indices.
+        source: Float tensor of shape :math:`(B, C, H, W)` or :math:`(C, H, W)` to be
+            reduced.
+        mask: Integer tensor of shape :math:`(B, H, W)` or :math:`(H, W)` containing the
+            integer indices.
+        num_cls: The number of classes in the possible masks.
         reduce: The reduction operation to be applied, one of 'prod', 'mean', 'amax'
             or 'amin'. Defaults to 'mean'.
-        num_cls: The number of classes in the possible masks. If None, the number of classes
-            is inferred from the unique elements in `mask`. This is useful when not all
-            classes are present in the mask.

     Returns:
-        A tensor of shape (B, C, N) or (C, N) where N is the number of unique elements
-        in `mask` or `num_cls` if specified.
+        A tensor of shape :math:`(B, C, N)` or :math:`(C, N)` where :math:`N`
+        corresponds to `num_cls`.
     """
     if source.dim() == 3:
         return _mask_reduce(source, mask, reduce, num_cls)
@@ -55,29 +55,26 @@ def pool_masked(


 def _mask_reduce(
-    source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
+    source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean"
 ) -> Tensor:
     output = _mask_reduce_batched(
-        source.unsqueeze(0), mask.unsqueeze(0), num_cls=num_cls
+        source.unsqueeze(0), mask.unsqueeze(0), num_cls=num_cls, reduce=reduce
     )
     return output.squeeze(0)


 def _mask_reduce_batched(
-    source: Tensor, mask: Tensor, num_cls: Optional[int] = None
+    source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean"
 ) -> Tensor:
     b, c, h, w = source.shape
-    if num_cls is None:
-        cls = mask.unique(sorted=True)
-    else:
-        cls = torch.arange(num_cls, device=mask.device)
+    cls = torch.arange(num_cls, device=mask.device)
     num_cls = cls.size(0)
     # create output tensor
     output = source.new_zeros((b, c, num_cls))  # (B C N)
     mask = mask.unsqueeze(1).expand(-1, c, -1, -1).view(b, c, -1)  # (B C HW)
     source = source.view(b, c, -1)  # (B C HW)
     output.scatter_reduce_(
-        dim=2, index=mask, src=source, reduce="mean", include_self=False
+        dim=2, index=mask, src=source, reduce=reduce, include_self=False
     )  # (B C N)
     # scatter_reduce_ produces NaNs if the count is zero
     output = torch.nan_to_num(output, nan=0.0)
diff --git a/tests/models/test_ModelUtils.py b/tests/models/test_ModelUtils.py
index 0c2b292b4..2bde9f4fd 100644
--- a/tests/models/test_ModelUtils.py
+++ b/tests/models/test_ModelUtils.py
@@ -104,15 +104,6 @@ def test_masked_pooling_manual(
         self, feature_map2: Tensor, mask2: Tensor, expected_result2: Tensor
     ) -> None:
         assert out_manual.shape == (1, 3, 2)
         assert (out_manual == expected_result2[:, :2]).all()

-    def test_masked_pooling_auto(
-        self, feature_map2: Tensor, mask2: Tensor, expected_result2: Tensor
-    ) -> None:
-        out_auto = pool_masked(
-            feature_map2.unsqueeze(0), mask2.unsqueeze(0), num_cls=None
-        )
-        assert out_auto.shape == (1, 3, 2)
-        assert (out_auto == expected_result2[:, :2]).all()
-
     # Type ignore because untyped decorator makes function untyped.
     @pytest.mark.parametrize(
         "feature_map, mask, expected_result",
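
For reference, a minimal sketch of how the updated `pool_masked` API is called after this change. The tensor shapes and values below are illustrative only (not taken from the PR), and it assumes `pool_masked` is importable from `lightly.models.utils`, the module modified above.

```python
import torch

from lightly.models.utils import pool_masked

# Illustrative inputs: a batch of 2 feature maps with 3 channels on a 4x4 grid,
# and an integer mask assigning each pixel to one of 5 classes (values 0..4).
source = torch.randn(2, 3, 4, 4)
mask = torch.randint(0, 5, (2, 4, 4))

# num_cls is now a required argument; reduce is forwarded to scatter_reduce_,
# so 'prod', 'mean', 'amax' and 'amin' are supported.
pooled = pool_masked(source, mask, num_cls=5, reduce="mean")

print(pooled.shape)  # torch.Size([2, 3, 5])
# Classes that never appear in the mask end up as 0 (NaNs are replaced).
```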