Skip to content

Commit a8de07d

Browse files
SkafteNickiBorda
andauthored
Fix corner case in manually specifying compute_groups in MetricCollection (#2979)
* implementation * Apply suggestions from code review --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Jirka B <j.borovec+github@gmail.com>
1 parent 8eef0d3 commit a8de07d

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

src/torchmetrics/collections.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from torchmetrics.metric import Metric
2626
from torchmetrics.utilities import rank_zero_warn
27-
from torchmetrics.utilities.data import _flatten_dict, allclose
27+
from torchmetrics.utilities.data import _flatten, _flatten_dict, allclose
2828
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
2929
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
3030

@@ -90,7 +90,9 @@ class name as key for the output dict.
9090
due to the internal logic of ``forward`` preventing this. Secondly, since we compute groups share metric
9191
states by reference, calling ``.items()``, ``.values()`` etc. on the metric collection will break this
9292
reference and a copy of states are instead returned in this case (reference will be reestablished on the next
93-
call to ``update``).
93+
call to ``update``). Do note that for the time being that if you are manually specifying compute groups in
94+
nested collections, these are not compatible with the compute groups of the parent collection and will be
95+
overridden.
9496
9597
.. important::
9698
Metric collections can be nested at initialization (see last example) but the output of the collection will
@@ -192,7 +194,6 @@ class name of the metric:
192194
"""
193195

194196
_modules: dict[str, Metric] # type: ignore[assignment]
195-
_groups: Dict[int, List[str]]
196197
__jit_unused_properties__: ClassVar[list[str]] = ["metric_state"]
197198

198199
def __init__(
@@ -210,7 +211,7 @@ def __init__(
210211
self._enable_compute_groups = compute_groups
211212
self._groups_checked: bool = False
212213
self._state_is_copy: bool = False
213-
214+
self._groups: Dict[int, list[str]] = {}
214215
self.add_metrics(metrics, *additional_metrics)
215216

216217
@property
@@ -338,7 +339,7 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
338339
of just passed by reference
339340
340341
"""
341-
if not self._state_is_copy:
342+
if not self._state_is_copy and self._groups_checked:
342343
for cg in self._groups.values():
343344
m0 = getattr(self, cg[0])
344345
for i in range(1, len(cg)):
@@ -495,7 +496,6 @@ def add_metrics(
495496
"Unknown input to MetricCollection. Expected, `Metric`, `MetricCollection` or `dict`/`sequence` of the"
496497
f" previous, but got {metrics}"
497498
)
498-
499499
self._groups_checked = False
500500
if self._enable_compute_groups:
501501
self._init_compute_groups()
@@ -518,9 +518,15 @@ def _init_compute_groups(self) -> None:
518518
f"Input {metric} in `compute_groups` argument does not match a metric in the collection."
519519
f" Please make sure that {self._enable_compute_groups} matches {self.keys(keep_base=True)}"
520520
)
521+
# add metrics not specified in compute groups as their own group
522+
already_in_group = _flatten(self._groups.values()) # type: ignore
523+
counter = len(self._groups)
524+
for k in self.keys(keep_base=True):
525+
if k not in already_in_group:
526+
self._groups[counter] = [k] # type: ignore
527+
counter += 1
521528
self._groups_checked = True
522529
else:
523-
# Initialize all metrics as their own compute group
524530
self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=True))}
525531

526532
@property

tests/unittests/bases/test_collections.py

+22
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,28 @@ def test_compute_group_define_by_user():
573573
assert m.compute()
574574

575575

576+
def test_compute_group_define_by_user_outside_specs():
577+
"""Check that user can provide compute groups with missing metrics in the specs."""
578+
m = MetricCollection(
579+
MulticlassConfusionMatrix(3),
580+
MulticlassRecall(3),
581+
MulticlassPrecision(3),
582+
MulticlassAccuracy(3),
583+
compute_groups=[["MulticlassRecall", "MulticlassPrecision"]],
584+
)
585+
assert m._groups_checked
586+
assert m.compute_groups == {
587+
0: ["MulticlassRecall", "MulticlassPrecision"],
588+
1: ["MulticlassConfusionMatrix"],
589+
2: ["MulticlassAccuracy"],
590+
}
591+
592+
preds = torch.randn(10, 3).softmax(dim=-1)
593+
target = torch.randint(3, (10,))
594+
m.update(preds, target)
595+
assert m.compute()
596+
597+
576598
def test_classwise_wrapper_compute_group():
577599
"""Check that user can provide compute groups."""
578600
classwise_accuracy = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="accuracy")

0 commit comments

Comments
 (0)