@@ -377,7 +377,19 @@ def test_multiclass_fbeta_score_half_gpu(self, inputs, module, functional, compa
377
377
378
378
379
379
# Deterministic fixtures for the top-k multiclass F-beta/F1 tests.
# 3-sample case: with top_k=1 the per-row argmax predictions are [1, 1, 2]
# (sample 0 is misclassified, giving the 2/3 micro score used below); with
# top_k=2 every target is among the two highest-scoring classes.
_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([
    [0.35, 0.4, 0.25],
    [0.1, 0.5, 0.4],
    [0.2, 0.1, 0.7],
])

# 4-sample case over 3 classes, constructed so the expected scores differ
# between top_k = 1, 2 and 3 (0.25 / 0.75 / 1.0 micro) for every averaging
# mode exercised in the parametrization.
_mc_k_target2 = torch.tensor([0, 1, 2, 0])
_mc_k_preds2 = torch.tensor([
    [0.1, 0.2, 0.7],
    [0.4, 0.4, 0.2],
    [0.3, 0.3, 0.4],
    [0.3, 0.3, 0.4],
])
381
393
382
394
383
395
@pytest .mark .parametrize (
@@ -391,7 +403,33 @@ def test_multiclass_fbeta_score_half_gpu(self, inputs, module, functional, compa
391
403
("k" , "preds" , "target" , "average" , "expected_fbeta" , "expected_f1" ),
392
404
[
393
405
(1 , _mc_k_preds , _mc_k_target , "micro" , torch .tensor (2 / 3 ), torch .tensor (2 / 3 )),
394
- (2 , _mc_k_preds , _mc_k_target , "micro" , torch .tensor (5 / 6 ), torch .tensor (2 / 3 )),
406
+ (2 , _mc_k_preds , _mc_k_target , "micro" , torch .tensor (1.0 ), torch .tensor (1.0 )),
407
+ (1 , _mc_k_preds2 , _mc_k_target2 , "micro" , torch .tensor (0.25 ), torch .tensor (0.25 )),
408
+ (2 , _mc_k_preds2 , _mc_k_target2 , "micro" , torch .tensor (0.75 ), torch .tensor (0.75 )),
409
+ (3 , _mc_k_preds2 , _mc_k_target2 , "micro" , torch .tensor (1.0 ), torch .tensor (1.0 )),
410
+ (1 , _mc_k_preds2 , _mc_k_target2 , "macro" , torch .tensor (0.2381 ), torch .tensor (0.1667 )),
411
+ (2 , _mc_k_preds2 , _mc_k_target2 , "macro" , torch .tensor (0.7963 ), torch .tensor (0.7778 )),
412
+ (3 , _mc_k_preds2 , _mc_k_target2 , "macro" , torch .tensor (1.0 ), torch .tensor (1.0 )),
413
+ (1 , _mc_k_preds2 , _mc_k_target2 , "weighted" , torch .tensor (0.1786 ), torch .tensor (0.1250 )),
414
+ (2 , _mc_k_preds2 , _mc_k_target2 , "weighted" , torch .tensor (0.7361 ), torch .tensor (0.7500 )),
415
+ (3 , _mc_k_preds2 , _mc_k_target2 , "weighted" , torch .tensor (1.0 ), torch .tensor (1.0 )),
416
+ (
417
+ 1 ,
418
+ _mc_k_preds2 ,
419
+ _mc_k_target2 ,
420
+ "none" ,
421
+ torch .tensor ([0.0000 , 0.0000 , 0.7143 ]),
422
+ torch .tensor ([0.0000 , 0.0000 , 0.5000 ]),
423
+ ),
424
+ (
425
+ 2 ,
426
+ _mc_k_preds2 ,
427
+ _mc_k_target2 ,
428
+ "none" ,
429
+ torch .tensor ([0.5556 , 1.0000 , 0.8333 ]),
430
+ torch .tensor ([0.6667 , 1.0000 , 0.6667 ]),
431
+ ),
432
+ (3 , _mc_k_preds2 , _mc_k_target2 , "none" , torch .tensor ([1.0 , 1.0 , 1.0 ]), torch .tensor ([1.0 , 1.0 , 1.0 ])),
395
433
],
396
434
)
397
435
def test_top_k (
@@ -404,14 +442,73 @@ def test_top_k(
404
442
expected_fbeta : Tensor ,
405
443
expected_f1 : Tensor ,
406
444
):
407
- """A simple test to check that top_k works as expected."""
445
+ """A comprehensive test to check that top_k works as expected."""
408
446
class_metric = metric_class (top_k = k , average = average , num_classes = 3 )
409
447
class_metric .update (preds , target )
410
448
411
449
result = expected_fbeta if class_metric .beta != 1.0 else expected_f1
412
450
413
- assert torch .isclose (class_metric .compute (), result )
414
- assert torch .isclose (metric_fn (preds , target , top_k = k , average = average , num_classes = 3 ), result )
451
+ assert torch .allclose (class_metric .compute (), result , atol = 1e-4 , rtol = 1e-4 )
452
+ assert torch .allclose (
453
+ metric_fn (preds , target , top_k = k , average = average , num_classes = 3 ), result , atol = 1e-4 , rtol = 1e-4
454
+ )
455
+
456
+
457
@pytest.mark.parametrize("num_classes", [5])
def test_multiclassf1score_with_top_k(num_classes):
    """Check two invariants of the top-k macro F1 score on random data.

    Args:
        num_classes: Number of classes in the classification task.

    Invariants verified:
    1. The score is non-decreasing as ``top_k`` grows.
    2. The score is exactly 1 once ``top_k`` reaches ``num_classes``
       (every target is then trivially inside the top-k candidates).

    """
    preds = torch.randn(200, num_classes).softmax(dim=-1)
    target = torch.randint(num_classes, (200,))

    last = 0.0
    for k in range(1, num_classes + 1):
        metric = MulticlassF1Score(num_classes=num_classes, top_k=k, average="macro")
        current = metric(preds, target)
        assert current >= last, f"F1 score did not increase for top_k={k}"
        last = current

    # Loop exits with k == num_classes, so this checks the degenerate case.
    assert torch.isclose(
        current, torch.tensor(1.0)
    ), f"F1 score is not 1 for top_k={k} when num_classes={num_classes}"
484
+
485
+
486
def test_multiclass_f1_score_top_k_equivalence():
    """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1653.

    Top-k F1 must equal the top-1 F1 of "corrected" predictions, where a
    prediction is replaced by the target whenever the target appears among
    the model's top-3 candidates.
    """
    num_classes = 5
    preds = torch.randn(200, num_classes).softmax(dim=-1)
    target = torch.randint(num_classes, (200,))

    # Rank classes per sample; keep the three best and the single best guess.
    ranked = torch.argsort(preds, dim=1, descending=True)
    top3 = ranked[:, :3]
    top1 = top3[:, 0]

    # Build the corrected top-1 predictions: count a sample as correct when
    # its target shows up anywhere in the top-3 candidates.
    hit = (target.unsqueeze(1) == top3).any(dim=1)
    corrected = torch.where(hit, target, top1)

    score_top3 = MulticlassF1Score(num_classes=num_classes, top_k=3, average="macro")(preds, target)
    score_corrected = MulticlassF1Score(num_classes=num_classes, top_k=1, average="macro")(corrected, target)

    assert torch.isclose(
        score_top3, score_corrected
    ), f"Top-3 F1 score ({score_top3}) does not match corrected top-1 F1 score ({score_corrected})"
415
512
416
513
417
514
def _reference_sklearn_fbeta_score_multilabel_global (preds , target , sk_fn , ignore_index , average , zero_division ):
0 commit comments