Skip to content

Commit b277021

Browse files
committed
fix logic + tests
1 parent 4f89985 commit b277021

File tree

2 files changed

+47
-14
lines changed

2 files changed

+47
-14
lines changed

src/torchmetrics/collections.py

+10-14
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,7 @@ def update(self, *args: Any, **kwargs: Any) -> None:
247247
# only update the first member
248248
m0 = getattr(self, cg[0])
249249
m0.update(*args, **m0._filter_kwargs(**kwargs))
250-
if self._state_is_copy:
251-
# If we have deep copied state in between updates, reestablish link
252-
self._compute_groups_create_state_ref()
253-
self._state_is_copy = False
250+
self._compute_groups_create_state_ref()
254251
else: # the first update always do per metric to form compute groups
255252
for m in self.values(copy_state=False):
256253
m_kwargs = m._filter_kwargs(**kwargs)
@@ -339,16 +336,15 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
339336
of just passed by reference
340337
341338
"""
342-
if not self._state_is_copy and self._groups_checked:
343-
for cg in self._groups.values():
344-
m0 = getattr(self, cg[0])
345-
for i in range(1, len(cg)):
346-
mi = getattr(self, cg[i])
347-
for state in m0._defaults:
348-
m0_state = getattr(m0, state)
349-
# Determine if we just should set a reference or a full copy
350-
setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
351-
mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
339+
for cg in self._groups.values():
340+
m0 = getattr(self, cg[0])
341+
for i in range(1, len(cg)):
342+
mi = getattr(self, cg[i])
343+
for state in m0._defaults:
344+
m0_state = getattr(m0, state)
345+
# Determine if we just should set a reference or a full copy
346+
setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
347+
mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
352348
self._state_is_copy = copy
353349

354350
def compute(self) -> dict[str, Any]:

tests/unittests/bases/test_collections.py

+37
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
MultilabelAUROC,
3434
MultilabelAveragePrecision,
3535
)
36+
from torchmetrics.regression import PearsonCorrCoef
3637
from torchmetrics.text import BLEUScore
3738
from torchmetrics.utilities.checks import _allclose_recursive
3839
from unittests._helpers import seed_all
@@ -796,3 +797,39 @@ def test_collection_update():
796797

797798
for k, v in expected.items():
798799
torch.testing.assert_close(actual=actual.get(k), expected=v, rtol=1e-4, atol=1e-4)
800+
801+
802+
def test_collection_state_being_re_established_after_copy():
803+
"""Check that shared metrics states when using compute groups are re-established after a copy.
804+
805+
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2896
806+
807+
"""
808+
m1, m2 = PearsonCorrCoef(), PearsonCorrCoef()
809+
m12 = MetricCollection({"m1": m1, "m2": m2}, compute_groups=True)
810+
x1, y1 = torch.randn(100), torch.randn(100)
811+
m12.update(x1, y1)
812+
assert m12.compute_groups == {0: ["m1", "m2"]}
813+
814+
# Check that the states are pointing to the same location
815+
assert not m12._state_is_copy
816+
assert m12.m1.mean_x.data_ptr() == m12.m2.mean_x.data_ptr(), "States should point to the same location"
817+
818+
# Break the references between the states
819+
_ = m12.items()
820+
assert m12._state_is_copy
821+
assert m12.m1.mean_x.data_ptr() != m12.m2.mean_x.data_ptr(), "States should not point to the same location"
822+
823+
# Update should restore the references between the states
824+
x2, y2 = torch.randn(100), torch.randn(100)
825+
826+
m12.update(x2, y2)
827+
assert not m12._state_is_copy
828+
assert m12.m1.mean_x.data_ptr() == m12.m2.mean_x.data_ptr(), "States should point to the same location"
829+
830+
x3, y3 = torch.randn(100), torch.randn(100)
831+
m12.update(x3, y3)
832+
833+
assert not m12._state_is_copy
834+
assert m12.m1.mean_x.data_ptr() == m12.m2.mean_x.data_ptr(), "States should point to the same location"
835+
assert m12._equal_metric_states(m12.m1, m12.m2)

0 commit comments

Comments
 (0)