24
24
25
25
from torchmetrics .metric import Metric
26
26
from torchmetrics .utilities import rank_zero_warn
27
- from torchmetrics .utilities .data import _flatten_dict , allclose
27
+ from torchmetrics .utilities .data import _flatten , _flatten_dict , allclose
28
28
from torchmetrics .utilities .imports import _MATPLOTLIB_AVAILABLE
29
29
from torchmetrics .utilities .plot import _AX_TYPE , _PLOT_OUT_TYPE , plot_single_or_multi_val
30
30
@@ -90,7 +90,9 @@ class name as key for the output dict.
90
90
due to the internal logic of ``forward`` preventing this. Secondly, since we compute groups share metric
91
91
states by reference, calling ``.items()``, ``.values()`` etc. on the metric collection will break this
92
92
reference and a copy of states are instead returned in this case (reference will be reestablished on the next
93
- call to ``update``).
93
+ call to ``update``). Do note that for the time being that if you are manually specifying compute groups in
94
+ nested collections, these are not compatible with the compute groups of the parent collection and will be
95
+ overridden.
94
96
95
97
.. important::
96
98
Metric collections can be nested at initialization (see last example) but the output of the collection will
@@ -192,7 +194,6 @@ class name of the metric:
192
194
"""
193
195
194
196
_modules : dict [str , Metric ] # type: ignore[assignment]
195
- _groups : Dict [int , List [str ]]
196
197
__jit_unused_properties__ : ClassVar [list [str ]] = ["metric_state" ]
197
198
198
199
def __init__ (
@@ -210,7 +211,7 @@ def __init__(
210
211
self ._enable_compute_groups = compute_groups
211
212
self ._groups_checked : bool = False
212
213
self ._state_is_copy : bool = False
213
-
214
+ self . _groups : Dict [ int , list [ str ]] = {}
214
215
self .add_metrics (metrics , * additional_metrics )
215
216
216
217
@property
@@ -338,7 +339,7 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
338
339
of just passed by reference
339
340
340
341
"""
341
- if not self ._state_is_copy :
342
+ if not self ._state_is_copy and self . _groups_checked :
342
343
for cg in self ._groups .values ():
343
344
m0 = getattr (self , cg [0 ])
344
345
for i in range (1 , len (cg )):
@@ -495,7 +496,6 @@ def add_metrics(
495
496
"Unknown input to MetricCollection. Expected, `Metric`, `MetricCollection` or `dict`/`sequence` of the"
496
497
f" previous, but got { metrics } "
497
498
)
498
-
499
499
self ._groups_checked = False
500
500
if self ._enable_compute_groups :
501
501
self ._init_compute_groups ()
@@ -518,9 +518,15 @@ def _init_compute_groups(self) -> None:
518
518
f"Input { metric } in `compute_groups` argument does not match a metric in the collection."
519
519
f" Please make sure that { self ._enable_compute_groups } matches { self .keys (keep_base = True )} "
520
520
)
521
+ # add metrics not specified in compute groups as their own group
522
+ already_in_group = _flatten (self ._groups .values ()) # type: ignore
523
+ counter = len (self ._groups )
524
+ for k in self .keys (keep_base = True ):
525
+ if k not in already_in_group :
526
+ self ._groups [counter ] = [k ] # type: ignore
527
+ counter += 1
521
528
self ._groups_checked = True
522
529
else :
523
- # Initialize all metrics as their own compute group
524
530
self ._groups = {i : [str (k )] for i , k in enumerate (self .keys (keep_base = True ))}
525
531
526
532
@property
0 commit comments