    MultilabelAUROC,
    MultilabelAveragePrecision,
)
+ from torchmetrics.regression import PearsonCorrCoef
from torchmetrics.text import BLEUScore
from torchmetrics.utilities.checks import _allclose_recursive
from unittests._helpers import seed_all
@@ -328,30 +329,35 @@ def compute(self):
    "metrics, expected, preds, target",
    [
        # single metric forms its own compute group
-         (MulticlassAccuracy(num_classes=3), {0: ["MulticlassAccuracy"]}, _mc_preds, _mc_target),
+         pytest.param(
+             MulticlassAccuracy(num_classes=3), {0: ["MulticlassAccuracy"]}, _mc_preds, _mc_target, id="single_metric"
+         ),
        # two metrics of same class forms a compute group
-         (
+         pytest.param(
            {"acc0": MulticlassAccuracy(num_classes=3), "acc1": MulticlassAccuracy(num_classes=3)},
            {0: ["acc0", "acc1"]},
            _mc_preds,
            _mc_target,
+             id="two_metrics_of_same_class",
        ),
        # two metrics from registry forms a compute group
-         (
+         pytest.param(
            [MulticlassPrecision(num_classes=3), MulticlassRecall(num_classes=3)],
            {0: ["MulticlassPrecision", "MulticlassRecall"]},
            _mc_preds,
            _mc_target,
+             id="two_metrics_from_registry",
        ),
        # two metrics from different classes gives two compute groups
-         (
+         pytest.param(
            [MulticlassConfusionMatrix(num_classes=3), MulticlassRecall(num_classes=3)],
            {0: ["MulticlassConfusionMatrix"], 1: ["MulticlassRecall"]},
            _mc_preds,
            _mc_target,
+             id="two_metrics_from_different_classes",
        ),
        # multi group multi metric
-         (
+         pytest.param(
            [
                MulticlassConfusionMatrix(num_classes=3),
                MulticlassCohenKappa(num_classes=3),
@@ -361,9 +367,10 @@ def compute(self):
            {0: ["MulticlassConfusionMatrix", "MulticlassCohenKappa"], 1: ["MulticlassRecall", "MulticlassPrecision"]},
            _mc_preds,
            _mc_target,
+             id="multi_group_multi_metric",
        ),
        # Complex example
-         (
+         pytest.param(
            {
                "acc": MulticlassAccuracy(num_classes=3),
                "acc2": MulticlassAccuracy(num_classes=3),
@@ -375,19 +382,21 @@ def compute(self):
            {0: ["acc", "acc2", "f1", "recall"], 1: ["acc3"], 2: ["confmat"]},
            _mc_preds,
            _mc_target,
+             id="complex_example",
        ),
        # With list states
-         (
+         pytest.param(
            [
                MulticlassAUROC(num_classes=3, average="macro"),
                MulticlassAveragePrecision(num_classes=3, average="macro"),
            ],
            {0: ["MulticlassAUROC", "MulticlassAveragePrecision"]},
            _mc_preds,
            _mc_target,
+             id="with_list_states",
        ),
        # Nested collections
-         (
+         pytest.param(
            [
                MetricCollection(
                    MultilabelAUROC(num_labels=3, average="micro"),
@@ -410,6 +419,7 @@ def compute(self):
            },
            _ml_preds,
            _ml_target,
+             id="nested_collections",
        ),
    ],
)
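A quick aside, not part of the diff: the `expected` dictionaries in the cases above describe which metrics end up sharing state. The sketch below, assuming only a standard torchmetrics install (metric choices and tensor shapes are illustrative, not taken from the test file), shows how that grouping can be inspected directly.

```python
# Minimal sketch (not part of the diff above): inspecting the compute-group
# mapping that the parametrized cases assert.
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassPrecision, MulticlassRecall

collection = MetricCollection(
    [MulticlassPrecision(num_classes=3), MulticlassRecall(num_classes=3)],
    compute_groups=True,
)
preds = torch.randn(10, 3).softmax(dim=-1)  # probabilities for 3 classes
target = torch.randint(3, (10,))
collection.update(preds, target)  # compute groups are established on the first update
# Both metrics build on the same underlying statistics, so they land in one group.
print(collection.compute_groups)  # {0: ['MulticlassPrecision', 'MulticlassRecall']}
```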
@@ -796,3 +806,39 @@ def test_collection_update():

    for k, v in expected.items():
        torch.testing.assert_close(actual=actual.get(k), expected=v, rtol=1e-4, atol=1e-4)
+
+
+ def test_collection_state_being_re_established_after_copy():
+     """Check that shared metric states are re-established after a copy when using compute groups.
+
+     See issue: https://github.com/Lightning-AI/torchmetrics/issues/2896
+
+     """
+     m1, m2 = PearsonCorrCoef(), PearsonCorrCoef()
+     m12 = MetricCollection({"m1": m1, "m2": m2}, compute_groups=True)
+     x1, y1 = torch.randn(100), torch.randn(100)
+     m12.update(x1, y1)
+     assert m12.compute_groups == {0: ["m1", "m2"]}
+
+     # Check that the states are pointing to the same location
+     assert not m12._state_is_copy
+     assert m12.m1.mean_x.data_ptr() == m12.m2.mean_x.data_ptr(), "States should point to the same location"
+
+     # Break the references between the states
+     _ = m12.items()
+     assert m12._state_is_copy
+     assert m12.m1.mean_x.data_ptr() != m12.m2.mean_x.data_ptr(), "States should not point to the same location"
+
+     # Update should restore the references between the states
+     x2, y2 = torch.randn(100), torch.randn(100)
+
+     m12.update(x2, y2)
+     assert not m12._state_is_copy
+     assert m12.m1.mean_x.data_ptr() == m12.m2.mean_x.data_ptr(), "States should point to the same location"
+
+     x3, y3 = torch.randn(100), torch.randn(100)
+     m12.update(x3, y3)
+
+     assert not m12._state_is_copy
+     assert m12.m1.mean_x.data_ptr() == m12.m2.mean_x.data_ptr(), "States should point to the same location"
+     assert m12._equal_metric_states(m12.m1, m12.m2)
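As a plain usage note, not part of the change itself: the pattern this test guards against is the common one of reading results out of a collection between updates. A minimal illustration of the behaviour described in issue #2896, assuming only a standard torchmetrics install:

```python
# Illustration of the behaviour exercised by the test above; an assumption-level
# sketch, not an additional test from the PR.
import torch
from torchmetrics import MetricCollection
from torchmetrics.regression import PearsonCorrCoef

coll = MetricCollection({"m1": PearsonCorrCoef(), "m2": PearsonCorrCoef()}, compute_groups=True)
coll.update(torch.randn(100), torch.randn(100))  # first update establishes the compute group
_ = coll.items()                                 # reading the metrics copies the shared states apart
coll.update(torch.randn(100), torch.randn(100))  # the next update re-links them
# Both metrics once again point at the same underlying state tensor.
assert coll.m1.mean_x.data_ptr() == coll.m2.mean_x.data_ptr()
```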