Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lionel lig 5887 make arguments to masked pooling non optional #1776

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
22d8058
allow access to v2 transforms availability from anywhere
liopeer Nov 11, 2024
c963024
import unnecessary after exception
liopeer Nov 11, 2024
988a718
reformat
liopeer Nov 12, 2024
4b23bb7
add implementation of AddGridTransform
liopeer Nov 12, 2024
f4e74ab
make AddGridTransform importable
liopeer Nov 12, 2024
a3b7e15
add tests for AddGridTransform
liopeer Nov 12, 2024
ff7a14b
reformat
liopeer Nov 12, 2024
a6fb889
enhance docstring
liopeer Nov 12, 2024
59892a5
fix typing issues
liopeer Nov 12, 2024
edd1755
Merge branch 'master' into lionel-lig-5625-add-addgridtransform2
liopeer Nov 12, 2024
2e89959
fix import when transforms.v2 not available
liopeer Nov 12, 2024
1caaed7
add additional type ignore for 3.7 compatibility
liopeer Nov 12, 2024
7ea7675
reformat
liopeer Nov 12, 2024
357c4dc
add transform to docs
liopeer Nov 12, 2024
c91fdb8
change header
liopeer Nov 12, 2024
af3a02b
add explanation on data structures
liopeer Nov 12, 2024
038ec59
use kw args
liopeer Nov 12, 2024
d64264f
add assertion for mask dimension to be geq 2
liopeer Nov 12, 2024
2cd76cb
remove unnecessary fixtures
liopeer Nov 15, 2024
ccaa553
make argument order consistent
liopeer Nov 15, 2024
7ccf47f
add DetCon SingleView and MultiView transforms
liopeer Nov 18, 2024
8622953
add MultiViewTransform BaseClass for v2 transforms
liopeer Nov 18, 2024
0afb15e
add tests for DetCon transform
liopeer Nov 18, 2024
7b60c16
export all newly added transforms
liopeer Nov 18, 2024
3dee483
add DetCon single view and DetCon multi view transforms
liopeer Nov 18, 2024
55cf731
add torchvision transforms v2 compatible MultiViewTransforms
liopeer Nov 18, 2024
d341d59
make newly added transforms public
liopeer Nov 18, 2024
fa5d635
remove unnecessary fixtures
liopeer Nov 18, 2024
2aa8cdc
add tests for DetCon transform
liopeer Nov 18, 2024
f65e902
Merge branch 'master' into lionel-lig-5626-add-detcontransform
liopeer Nov 18, 2024
ad608c5
merge
liopeer Nov 18, 2024
a98dec5
remove wrongfully added files
liopeer Nov 18, 2024
5436335
add DetCon transform and MultiView transforms for v2 to docs
liopeer Nov 18, 2024
f42feef
fix docs references
liopeer Nov 18, 2024
00b9b60
fix import issues for minimal dependencies
liopeer Nov 18, 2024
d552340
fixing code format
liopeer Nov 18, 2024
3789630
add test for multiviewtransformv2
liopeer Nov 18, 2024
ac3639a
fix testing of multiview
liopeer Nov 18, 2024
56edcd4
use singular AddGridTransforms
liopeer Nov 19, 2024
6b477c3
consistent naming to DetConS
liopeer Nov 19, 2024
f51e834
adjust docstring reference numbering
liopeer Nov 19, 2024
8286ee3
name refactoring
liopeer Nov 19, 2024
cd37874
Merge branch 'master' into lionel-lig-5628-add-detconloss
liopeer Nov 21, 2024
02f5c75
Merge branch 'master' into lionel-lig-5628-add-detconloss
liopeer Nov 21, 2024
5928b57
start detconloss implementation
liopeer Nov 21, 2024
c4fc6ef
Merge branch 'master' into lionel-lig-5628-add-detconloss
liopeer Dec 23, 2024
3828ae3
implement detconloss
liopeer Jan 2, 2025
2a93b0c
test detconloss
liopeer Jan 2, 2025
285b490
make detconloss public
liopeer Jan 2, 2025
7f183a6
add detconloss to docs
liopeer Jan 2, 2025
bf080e5
Merge branch 'master' into lionel-lig-5628-add-detconloss
liopeer Jan 2, 2025
2f78124
Update lightly/loss/detcon_loss.py
liopeer Jan 3, 2025
18abcba
initial small fixes
liopeer Jan 3, 2025
a954bb5
add comments and avoid 0 division
liopeer Jan 3, 2025
aa07b86
remove labels_ext
liopeer Jan 3, 2025
32702a5
remove labels_ext
liopeer Jan 3, 2025
10579d4
revert normalization; some formatting
liopeer Jan 6, 2025
b8b7bf0
complete rewrite of tests
liopeer Jan 6, 2025
49819dc
fix typing issues
liopeer Jan 6, 2025
78e6c85
remove unused imports
liopeer Jan 6, 2025
9e37f70
Update lightly/loss/detcon_loss.py
liopeer Jan 6, 2025
f4d17c0
move to f-strings
liopeer Jan 6, 2025
9cfeeab
create test classes
liopeer Jan 6, 2025
e06f71d
Merge branch 'lionel-lig-5628-add-detconloss' of github.com:lightly-a…
liopeer Jan 6, 2025
d530c5b
squeeze instead of 0 indexing
liopeer Jan 6, 2025
d79ba68
formatting
liopeer Jan 6, 2025
b0bff0f
few more comments on tensor shapes
liopeer Jan 6, 2025
3211b88
additional comments on tensor shapes
liopeer Jan 6, 2025
c92dd4d
fix masked pooling ops
liopeer Jan 10, 2025
b8f0335
Merge branch 'master' into lionel-lig-5887-make-arguments-to-masked-p…
liopeer Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions lightly/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,23 @@


def pool_masked(
source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean"
) -> Tensor:
"""Reduce image feature maps (B, C, H, W) or (C, H, W) according to an integer
index given by `mask` (B, H, W) or (H, W).
"""Reduce image feature maps :math:`(B, C, H, W)` or :math:`(C, H, W)` according to
an integer index given by `mask` :math:`(B, H, W)` or :math:`(H, W)`.

Args:
source: Float tensor of shape (B, C, H, W) or (C, H, W) to be reduced.
mask: Integer tensor of shape (B, H, W) or (H, W) containing the integer indices.
source: Float tensor of shape :math:`(B, C, H, W)` or :math:`(C, H, W)` to be
reduced.
mask: Integer tensor of shape :math:`(B, H, W)` or :math:`(H, W)` containing the
integer indices.
num_cls: The number of classes in the possible masks.
reduce: The reduction operation to be applied, one of 'prod', 'mean', 'amax' or
'amin'. Defaults to 'mean'.
num_cls: The number of classes in the possible masks. If None, the number of classes
is inferred from the unique elements in `mask`. This is useful when not all
classes are present in the mask.

Returns:
A tensor of shape (B, C, N) or (C, N) where N is the number of unique elements
in `mask` or `num_cls` if specified.
A tensor of shape :math:`(B, C, N)` or :math:`(C, N)` where :math:`N`
corresponds to `num_cls`.
"""
if source.dim() == 3:
return _mask_reduce(source, mask, reduce, num_cls)
Expand All @@ -55,29 +55,26 @@ def pool_masked(


def _mask_reduce(
source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean"
) -> Tensor:
output = _mask_reduce_batched(
source.unsqueeze(0), mask.unsqueeze(0), num_cls=num_cls
source.unsqueeze(0), mask.unsqueeze(0), num_cls=num_cls, reduce=reduce
)
return output.squeeze(0)


def _mask_reduce_batched(
source: Tensor, mask: Tensor, num_cls: Optional[int] = None
source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean"
) -> Tensor:
b, c, h, w = source.shape
if num_cls is None:
cls = mask.unique(sorted=True)
else:
cls = torch.arange(num_cls, device=mask.device)
cls = torch.arange(num_cls, device=mask.device)
num_cls = cls.size(0)
# create output tensor
output = source.new_zeros((b, c, num_cls)) # (B C N)
mask = mask.unsqueeze(1).expand(-1, c, -1, -1).view(b, c, -1) # (B C HW)
source = source.view(b, c, -1) # (B C HW)
output.scatter_reduce_(
dim=2, index=mask, src=source, reduce="mean", include_self=False
dim=2, index=mask, src=source, reduce=reduce, include_self=False
) # (B C N)
# scatter_reduce_ produces NaNs if the count is zero
output = torch.nan_to_num(output, nan=0.0)
Expand Down
9 changes: 0 additions & 9 deletions tests/models/test_ModelUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,6 @@ def test_masked_pooling_manual(
assert out_manual.shape == (1, 3, 2)
assert (out_manual == expected_result2[:, :2]).all()

def test_masked_pooling_auto(
self, feature_map2: Tensor, mask2: Tensor, expected_result2: Tensor
) -> None:
out_auto = pool_masked(
feature_map2.unsqueeze(0), mask2.unsqueeze(0), num_cls=None
)
assert out_auto.shape == (1, 3, 2)
assert (out_auto == expected_result2[:, :2]).all()

# Type ignore because untyped decorator makes function untyped.
@pytest.mark.parametrize(
"feature_map, mask, expected_result",
Expand Down