Skip to content

Commit cafd7cf

Browse files
Fix logic in how metric states referencing is handled in MetricCollection (#2990)
* fix logic + tests * changelog * fix tests --------- Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 0423a2c commit cafd7cf

File tree

3 files changed

+61
-13
lines changed

3 files changed

+61
-13
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3636

3737
### Fixed
3838

39+
- Fixed logic in how metric states referencing is handled in `MetricCollection` ([#2990](https://github.com/PyTorchLightning/metrics/pull/2990))
40+
41+
3942
- Fixed integration between classwise wrapper and metric tracker ([#3004](https://github.com/PyTorchLightning/metrics/pull/3004))
4043

4144

src/torchmetrics/collections.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,8 @@ def update(self, *args: Any, **kwargs: Any) -> None:
247247
# only update the first member
248248
m0 = getattr(self, cg[0])
249249
m0.update(*args, **m0._filter_kwargs(**kwargs))
250-
if self._state_is_copy:
251-
# If we have deep copied state in between updates, reestablish link
252-
self._compute_groups_create_state_ref()
253-
self._state_is_copy = False
250+
self._state_is_copy = False
251+
self._compute_groups_create_state_ref()
254252
else: # the first update always do per metric to form compute groups
255253
for m in self.values(copy_state=False):
256254
m_kwargs = m._filter_kwargs(**kwargs)
@@ -259,6 +257,7 @@ def update(self, *args: Any, **kwargs: Any) -> None:
259257
if self._enable_compute_groups:
260258
self._merge_compute_groups()
261259
# create reference between states
260+
self._state_is_copy = False
262261
self._compute_groups_create_state_ref()
263262
self._groups_checked = True
264263

@@ -339,7 +338,7 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
339338
of just passed by reference
340339
341340
"""
342-
if not self._state_is_copy and self._groups_checked:
341+
if not self._state_is_copy: # only create reference if not already copied
343342
for cg in self._groups.values():
344343
m0 = getattr(self, cg[0])
345344
for i in range(1, len(cg)):

tests/unittests/bases/test_collections.py

+54-8
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
MultilabelAUROC,
3434
MultilabelAveragePrecision,
3535
)
36+
from torchmetrics.regression import PearsonCorrCoef
3637
from torchmetrics.text import BLEUScore
3738
from torchmetrics.utilities.checks import _allclose_recursive
3839
from unittests._helpers import seed_all
@@ -328,30 +329,35 @@ def compute(self):
328329
"metrics, expected, preds, target",
329330
[
330331
# single metric forms its own compute group
331-
(MulticlassAccuracy(num_classes=3), {0: ["MulticlassAccuracy"]}, _mc_preds, _mc_target),
332+
pytest.param(
333+
MulticlassAccuracy(num_classes=3), {0: ["MulticlassAccuracy"]}, _mc_preds, _mc_target, id="single_metric"
334+
),
332335
# two metrics of same class forms a compute group
333-
(
336+
pytest.param(
334337
{"acc0": MulticlassAccuracy(num_classes=3), "acc1": MulticlassAccuracy(num_classes=3)},
335338
{0: ["acc0", "acc1"]},
336339
_mc_preds,
337340
_mc_target,
341+
id="two_metrics_of_same_class",
338342
),
339343
# two metrics from registry forms a compute group
340-
(
344+
pytest.param(
341345
[MulticlassPrecision(num_classes=3), MulticlassRecall(num_classes=3)],
342346
{0: ["MulticlassPrecision", "MulticlassRecall"]},
343347
_mc_preds,
344348
_mc_target,
349+
id="two_metrics_from_registry",
345350
),
346351
# two metrics from different classes gives two compute groups
347-
(
352+
pytest.param(
348353
[MulticlassConfusionMatrix(num_classes=3), MulticlassRecall(num_classes=3)],
349354
{0: ["MulticlassConfusionMatrix"], 1: ["MulticlassRecall"]},
350355
_mc_preds,
351356
_mc_target,
357+
id="two_metrics_from_different_classes",
352358
),
353359
# multi group multi metric
354-
(
360+
pytest.param(
355361
[
356362
MulticlassConfusionMatrix(num_classes=3),
357363
MulticlassCohenKappa(num_classes=3),
@@ -361,9 +367,10 @@ def compute(self):
361367
{0: ["MulticlassConfusionMatrix", "MulticlassCohenKappa"], 1: ["MulticlassRecall", "MulticlassPrecision"]},
362368
_mc_preds,
363369
_mc_target,
370+
id="multi_group_multi_metric",
364371
),
365372
# Complex example
366-
(
373+
pytest.param(
367374
{
368375
"acc": MulticlassAccuracy(num_classes=3),
369376
"acc2": MulticlassAccuracy(num_classes=3),
@@ -375,19 +382,21 @@ def compute(self):
375382
{0: ["acc", "acc2", "f1", "recall"], 1: ["acc3"], 2: ["confmat"]},
376383
_mc_preds,
377384
_mc_target,
385+
id="complex_example",
378386
),
379387
# With list states
380-
(
388+
pytest.param(
381389
[
382390
MulticlassAUROC(num_classes=3, average="macro"),
383391
MulticlassAveragePrecision(num_classes=3, average="macro"),
384392
],
385393
{0: ["MulticlassAUROC", "MulticlassAveragePrecision"]},
386394
_mc_preds,
387395
_mc_target,
396+
id="with_list_states",
388397
),
389398
# Nested collections
390-
(
399+
pytest.param(
391400
[
392401
MetricCollection(
393402
MultilabelAUROC(num_labels=3, average="micro"),
@@ -410,6 +419,7 @@ def compute(self):
410419
},
411420
_ml_preds,
412421
_ml_target,
422+
id="nested_collections",
413423
),
414424
],
415425
)
@@ -796,3 +806,39 @@ def test_collection_update():
796806

797807
for k, v in expected.items():
798808
torch.testing.assert_close(actual=actual.get(k), expected=v, rtol=1e-4, atol=1e-4)
809+
810+
811+
def test_collection_state_being_re_established_after_copy():
812+
"""Check that shared metrics states when using compute groups are re-established after a copy.
813+
814+
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2896
815+
816+
"""
817+
m1, m2 = PearsonCorrCoef(), PearsonCorrCoef()
818+
m12 = MetricCollection({"m1": m1, "m2": m2}, compute_groups=True)
819+
x1, y1 = torch.randn(100), torch.randn(100)
820+
m12.update(x1, y1)
821+
assert m12.compute_groups == {0: ["m1", "m2"]}
822+
823+
# Check that the states are pointing to the same location
824+
assert not m12._state_is_copy
825+
assert m12.m1.mean_x.data_ptr() == m12.m2.mean_x.data_ptr(), "States should point to the same location"
826+
827+
# Break the references between the states
828+
_ = m12.items()
829+
assert m12._state_is_copy
830+
assert m12.m1.mean_x.data_ptr() != m12.m2.mean_x.data_ptr(), "States should not point to the same location"
831+
832+
# Update should restore the references between the states
833+
x2, y2 = torch.randn(100), torch.randn(100)
834+
835+
m12.update(x2, y2)
836+
assert not m12._state_is_copy
837+
assert m12.m1.mean_x.data_ptr() == m12.m2.mean_x.data_ptr(), "States should point to the same location"
838+
839+
x3, y3 = torch.randn(100), torch.randn(100)
840+
m12.update(x3, y3)
841+
842+
assert not m12._state_is_copy
843+
assert m12.m1.mean_x.data_ptr() == m12.m2.mean_x.data_ptr(), "States should point to the same location"
844+
assert m12._equal_metric_states(m12.m1, m12.m2)

0 commit comments

Comments
 (0)