
Commit 4d9c843

rittik9, Borda, and pre-commit-ci[bot] authored

Fix top_k for multiclass-f1score (#2839)

* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka B <j.borovec+github@gmail.com>

1 parent 8827e64 commit 4d9c843

File tree

6 files changed (+258, -25)


CHANGELOG.md (+3)

@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed issue with shared state in metric collection when using dice score ([#2848](https://github.com/PyTorchLightning/metrics/pull/2848))
 
 
+- Fixed `top_k` for `multiclassf1score` with one-hot encoding ([#2839](https://github.com/Lightning-AI/torchmetrics/issues/2839))
+
+
 ---
 
 ## [1.6.0] - 2024-11-12

src/torchmetrics/functional/classification/stat_scores.py (+27)

@@ -340,6 +340,30 @@ def _multiclass_stat_scores_format(
     return preds, target
 
 
+def _refine_preds_oh(preds: Tensor, preds_oh: Tensor, target: Tensor, top_k: int) -> Tensor:
+    """Refines prediction one-hot encodings by replacing entries with target one-hot when there's an intersection.
+
+    When no intersection is found between the top-k predictions and target, uses the top-1 prediction.
+
+    Args:
+        preds: Original prediction tensor with probabilities/logits
+        preds_oh: Current one-hot encoded predictions from top-k selection
+        target: Target tensor with class indices
+        top_k: Number of top predictions to consider
+
+    Returns:
+        Refined one-hot encoded predictions tensor
+
+    """
+    preds = preds.squeeze()
+    target = target.squeeze()
+    top_k_indices = torch.topk(preds, k=top_k, dim=1).indices
+    top_1_indices = top_k_indices[:, 0]
+    target_in_topk = torch.any(top_k_indices == target.unsqueeze(1), dim=1)
+    result = torch.where(target_in_topk, target, top_1_indices)
+    return torch.zeros_like(preds_oh, dtype=torch.int32).scatter_(-1, result.unsqueeze(1).unsqueeze(1), 1)
+
+
 def _multiclass_stat_scores_update(
     preds: Tensor,
     target: Tensor,
@@ -371,13 +395,16 @@ def _multiclass_stat_scores_update(
 
     if top_k > 1:
         preds_oh = torch.movedim(select_topk(preds, topk=top_k, dim=1), 1, -1)
+        preds_oh = _refine_preds_oh(preds, preds_oh, target, top_k)
     else:
         preds_oh = torch.nn.functional.one_hot(
             preds.long(), num_classes + 1 if ignore_index is not None and not ignore_in else num_classes
         )
+
     target_oh = torch.nn.functional.one_hot(
         target.long(), num_classes + 1 if ignore_index is not None and not ignore_in else num_classes
    )
+
    if ignore_index is not None:
        if 0 <= ignore_index <= num_classes - 1:
            target_oh[target == ignore_index, :] = -1
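
For readers skimming the diff: `_refine_preds_oh` makes a sample count as predicting its target class whenever the target appears anywhere in its top-k scores, falling back to the top-1 class otherwise. A minimal standalone sketch of the same idea (the helper name `refine_topk_preds` and the toy batch are illustrative, not part of this commit):

import torch

def refine_topk_preds(preds: torch.Tensor, target: torch.Tensor, top_k: int) -> torch.Tensor:
    # Indices of the top-k scoring classes per sample, shape (batch, top_k).
    topk_idx = torch.topk(preds, k=top_k, dim=1).indices
    # Did the target land anywhere in the top-k?
    hit = (topk_idx == target.unsqueeze(1)).any(dim=1)
    # Credit the target class on a hit, otherwise fall back to the top-1 class.
    chosen = torch.where(hit, target, topk_idx[:, 0])
    return torch.nn.functional.one_hot(chosen, num_classes=preds.shape[1])

preds = torch.tensor([[0.1, 0.2, 0.7], [0.4, 0.4, 0.2], [0.3, 0.3, 0.4], [0.3, 0.3, 0.4]])
target = torch.tensor([0, 1, 2, 0])
print(refine_topk_preds(preds, target, top_k=2).argmax(dim=1))
# tensor([2, 1, 2, 0]) -- only the first sample misses: its target 0 is not in its top-2 {2, 1}

Note the toy batch contains tied scores; the classes shown assume torch.topk's observed CPU tie-breaking toward lower indices, which the new tests below also rely on. A consequence of the refinement: metrics built on these one-hot predictions are non-decreasing in top_k and reach 1.0 once top_k equals num_classes.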

tests/unittests/classification/test_f_beta.py (+102, -5)

@@ -377,7 +377,19 @@ def test_multiclass_fbeta_score_half_gpu(self, inputs, module, functional, compa
 
 
 _mc_k_target = torch.tensor([0, 1, 2])
-_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
+_mc_k_preds = torch.tensor([
+    [0.35, 0.4, 0.25],
+    [0.1, 0.5, 0.4],
+    [0.2, 0.1, 0.7],
+])
+
+_mc_k_target2 = torch.tensor([0, 1, 2, 0])
+_mc_k_preds2 = torch.tensor([
+    [0.1, 0.2, 0.7],
+    [0.4, 0.4, 0.2],
+    [0.3, 0.3, 0.4],
+    [0.3, 0.3, 0.4],
+])
 
 
 @pytest.mark.parametrize(
@@ -391,7 +403,33 @@ def test_multiclass_fbeta_score_half_gpu(self, inputs, module, functional, compa
     ("k", "preds", "target", "average", "expected_fbeta", "expected_f1"),
     [
         (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)),
-        (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(5 / 6), torch.tensor(2 / 3)),
+        (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0), torch.tensor(1.0)),
+        (1, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.25), torch.tensor(0.25)),
+        (2, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.75), torch.tensor(0.75)),
+        (3, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(1.0), torch.tensor(1.0)),
+        (1, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.2381), torch.tensor(0.1667)),
+        (2, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.7963), torch.tensor(0.7778)),
+        (3, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(1.0), torch.tensor(1.0)),
+        (1, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.1786), torch.tensor(0.1250)),
+        (2, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.7361), torch.tensor(0.7500)),
+        (3, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(1.0), torch.tensor(1.0)),
+        (
+            1,
+            _mc_k_preds2,
+            _mc_k_target2,
+            "none",
+            torch.tensor([0.0000, 0.0000, 0.7143]),
+            torch.tensor([0.0000, 0.0000, 0.5000]),
+        ),
+        (
+            2,
+            _mc_k_preds2,
+            _mc_k_target2,
+            "none",
+            torch.tensor([0.5556, 1.0000, 0.8333]),
+            torch.tensor([0.6667, 1.0000, 0.6667]),
+        ),
+        (3, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0, 1.0, 1.0])),
     ],
 )
 def test_top_k(
@@ -404,14 +442,73 @@
     expected_fbeta: Tensor,
     expected_f1: Tensor,
 ):
-    """A simple test to check that top_k works as expected."""
+    """A comprehensive test to check that top_k works as expected."""
     class_metric = metric_class(top_k=k, average=average, num_classes=3)
     class_metric.update(preds, target)
 
     result = expected_fbeta if class_metric.beta != 1.0 else expected_f1
 
-    assert torch.isclose(class_metric.compute(), result)
-    assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
+    assert torch.allclose(class_metric.compute(), result, atol=1e-4, rtol=1e-4)
+    assert torch.allclose(
+        metric_fn(preds, target, top_k=k, average=average, num_classes=3), result, atol=1e-4, rtol=1e-4
+    )
+
+
+@pytest.mark.parametrize("num_classes", [5])
+def test_multiclassf1score_with_top_k(num_classes):
+    """Test that F1 score increases monotonically with top_k and equals 1 when top_k equals num_classes.
+
+    Args:
+        num_classes: Number of classes in the classification task.
+
+    The test verifies two properties:
+        1. F1 score increases or stays the same as top_k increases
+        2. F1 score equals 1 when top_k equals num_classes
+
+    """
+    preds = torch.randn(200, num_classes).softmax(dim=-1)
+    target = torch.randint(num_classes, (200,))
+
+    previous_score = 0.0
+    for k in range(1, num_classes + 1):
+        f1_score = MulticlassF1Score(num_classes=num_classes, top_k=k, average="macro")
+        score = f1_score(preds, target)
+
+        assert score >= previous_score, f"F1 score did not increase for top_k={k}"
+        previous_score = score
+
+        if k == num_classes:
+            assert torch.isclose(
+                score, torch.tensor(1.0)
+            ), f"F1 score is not 1 for top_k={k} when num_classes={num_classes}"
+
+
+def test_multiclass_f1_score_top_k_equivalence():
+    """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1653.
+
+    Test that top-k F1 score is equivalent to corrected top-1 F1 score.
+    """
+    num_classes = 5
+
+    preds = torch.randn(200, num_classes).softmax(dim=-1)
+    target = torch.randint(num_classes, (200,))
+
+    f1_val_top3 = MulticlassF1Score(num_classes=num_classes, top_k=3, average="macro")
+    f1_val_top1 = MulticlassF1Score(num_classes=num_classes, top_k=1, average="macro")
+
+    pred_top_3 = torch.argsort(preds, dim=1, descending=True)[:, :3]
+    pred_top_1 = pred_top_3[:, 0]
+
+    target_in_top3 = (target.unsqueeze(1) == pred_top_3).any(dim=1)
+
+    pred_corrected_top3 = torch.where(target_in_top3, target, pred_top_1)
+
+    score_top3 = f1_val_top3(preds, target)
+    score_corrected = f1_val_top1(pred_corrected_top3, target)
+
+    assert torch.isclose(
+        score_top3, score_corrected
+    ), f"Top-3 F1 score ({score_top3}) does not match corrected top-1 F1 score ({score_corrected})"
 
 
 def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division):
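
As a sanity check on where an expected value such as the k=2 micro score of 0.75 comes from: with the refinement, the effective predictions for `_mc_k_preds2` become [2, 1, 2, 0] against targets [0, 1, 2, 0], i.e. 3 of 4 correct, and in this exhaustive multiclass setting micro F1 equals accuracy. A quick reproduction against a build containing this fix (illustrative, not part of the test suite):

import torch
from torchmetrics.classification import MulticlassF1Score

preds = torch.tensor([[0.1, 0.2, 0.7], [0.4, 0.4, 0.2], [0.3, 0.3, 0.4], [0.3, 0.3, 0.4]])
target = torch.tensor([0, 1, 2, 0])

# Refined top-2 predictions are [2, 1, 2, 0]: 3 of 4 samples correct -> 0.75.
metric = MulticlassF1Score(num_classes=3, top_k=2, average="micro")
print(metric(preds, target))  # tensor(0.7500)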

tests/unittests/classification/test_precision_recall.py (+33, -9)

@@ -385,8 +385,16 @@ def test_multiclass_precision_recall_half_gpu(self, inputs, module, functional,
 _mc_k_target = tensor([0, 1, 2])
 _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
 
-_mc_k_targets2 = torch.tensor([0, 0, 2])
-_mc_k_preds2 = torch.tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])
+_mc_k_targets2 = tensor([0, 0, 2])
+_mc_k_preds2 = tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])
+
+_mc_k_target3 = tensor([0, 1, 2, 0])
+_mc_k_preds3 = tensor([
+    [0.1, 0.2, 0.7],
+    [0.4, 0.4, 0.2],
+    [0.3, 0.3, 0.4],
+    [0.3, 0.3, 0.4],
+])
 
 
 @pytest.mark.parametrize(
@@ -395,10 +403,24 @@ def test_multiclass_precision_recall_half_gpu(self, inputs, module, functional,
 @pytest.mark.parametrize(
     ("k", "preds", "target", "average", "expected_prec", "expected_recall"),
     [
-        (1, _mc_k_preds, _mc_k_target, "micro", tensor(2 / 3), tensor(2 / 3)),
-        (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2), tensor(1.0)),
-        (1, _mc_k_preds2, _mc_k_targets2, "macro", tensor(1 / 3), tensor(1 / 2)),
-        (2, _mc_k_preds2, _mc_k_targets2, "macro", tensor(1 / 3), tensor(1 / 2)),
+        (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)),
+        (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0), torch.tensor(1.0)),
+        (3, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0), torch.tensor(1.0)),
+        (1, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1 / 3), torch.tensor(1 / 2)),
+        (2, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1 / 3), torch.tensor(1 / 2)),
+        (3, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1.0), torch.tensor(1.0)),
+        (1, _mc_k_preds3, _mc_k_target3, "macro", torch.tensor(0.1111), torch.tensor(0.3333)),
+        (2, _mc_k_preds3, _mc_k_target3, "macro", torch.tensor(0.8333), torch.tensor(0.8333)),
+        (3, _mc_k_preds3, _mc_k_target3, "macro", torch.tensor(1.0), torch.tensor(1.0)),
+        (1, _mc_k_preds3, _mc_k_target3, "micro", torch.tensor(0.2500), torch.tensor(0.2500)),
+        (2, _mc_k_preds3, _mc_k_target3, "micro", torch.tensor(0.7500), torch.tensor(0.7500)),
+        (3, _mc_k_preds3, _mc_k_target3, "micro", torch.tensor(1.0), torch.tensor(1.0)),
+        (1, _mc_k_preds3, _mc_k_target3, "weighted", torch.tensor(0.0833), torch.tensor(0.2500)),
+        (2, _mc_k_preds3, _mc_k_target3, "weighted", torch.tensor(0.8750), torch.tensor(0.7500)),
+        (3, _mc_k_preds3, _mc_k_target3, "weighted", torch.tensor(1.0), torch.tensor(1.0)),
+        (1, _mc_k_preds3, _mc_k_target3, "none", torch.tensor([0.0000, 0.0000, 0.3333]), torch.tensor([0.0, 0.0, 1.0])),
+        (2, _mc_k_preds3, _mc_k_target3, "none", torch.tensor([1.0000, 1.0000, 0.5000]), torch.tensor([0.5, 1.0, 1.0])),
+        (3, _mc_k_preds3, _mc_k_target3, "none", torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0, 1.0, 1.0])),
     ],
 )
 def test_top_k(
@@ -411,14 +433,16 @@
     expected_prec: Tensor,
     expected_recall: Tensor,
 ):
-    """A simple test to check that top_k works as expected."""
+    """A test to validate top_k functionality for precision and recall."""
     class_metric = metric_class(top_k=k, average=average, num_classes=3)
     class_metric.update(preds, target)
 
     result = expected_prec if metric_class.__name__ == "MulticlassPrecision" else expected_recall
 
-    assert torch.equal(class_metric.compute(), result)
-    assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
+    assert torch.allclose(class_metric.compute(), result, atol=1e-4, rtol=1e-4)
+    assert torch.allclose(
+        metric_fn(preds, target, top_k=k, average=average, num_classes=3), result, atol=1e-4, rtol=1e-4
+    )
 
 
 def _reference_sklearn_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division):
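
The per-class ("none") rows at k=2 follow directly from the refined predictions [2, 1, 2, 0] vs targets [0, 1, 2, 0]: class 0 has TP=1, FP=0, FN=1; class 1 has TP=1, FP=0, FN=0; class 2 has TP=1, FP=1, FN=0. Averaging the per-class values also reproduces the macro rows, e.g. (1.0 + 1.0 + 0.5) / 3 ≈ 0.8333 for precision. A small reproduction sketch (illustrative, assuming a build with this fix):

import torch
from torchmetrics.classification import MulticlassPrecision, MulticlassRecall

preds = torch.tensor([[0.1, 0.2, 0.7], [0.4, 0.4, 0.2], [0.3, 0.3, 0.4], [0.3, 0.3, 0.4]])
target = torch.tensor([0, 1, 2, 0])

# class 0: TP=1, FP=0, FN=1 -> precision 1.0, recall 0.5
# class 1: TP=1, FP=0, FN=0 -> precision 1.0, recall 1.0
# class 2: TP=1, FP=1, FN=0 -> precision 0.5, recall 1.0
print(MulticlassPrecision(num_classes=3, top_k=2, average="none")(preds, target))  # tensor([1.0000, 1.0000, 0.5000])
print(MulticlassRecall(num_classes=3, top_k=2, average="none")(preds, target))  # tensor([0.5000, 1.0000, 1.0000])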

tests/unittests/classification/test_specificity.py (+32, -7)

@@ -18,7 +18,7 @@
 import torch
 from scipy.special import expit as sigmoid
 from sklearn.metrics import confusion_matrix as sk_confusion_matrix
-from torch import Tensor, tensor
+from torch import Tensor
 from torchmetrics.classification.specificity import (
     BinarySpecificity,
     MulticlassSpecificity,
@@ -355,24 +355,49 @@ def test_multiclass_specificity_dtype_gpu(self, inputs, dtype):
 )
 
 
-_mc_k_target = tensor([0, 1, 2])
-_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
+_mc_k_target = torch.tensor([0, 1, 2])
+_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
+
+_mc_k_target2 = torch.tensor([0, 1, 2, 0])
+_mc_k_preds2 = torch.tensor([
+    [0.1, 0.2, 0.7],
+    [0.4, 0.4, 0.2],
+    [0.3, 0.3, 0.4],
+    [0.3, 0.3, 0.4],
+])
 
 
 @pytest.mark.parametrize(
     ("k", "preds", "target", "average", "expected_spec"),
     [
-        (1, _mc_k_preds, _mc_k_target, "micro", tensor(5 / 6)),
-        (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2)),
+        (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(5 / 6)),
+        (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0)),
+        (1, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.6111)),
+        (2, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.8889)),
+        (3, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(1.0)),
+        (1, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.6250)),
+        (2, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.8750)),
+        (3, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(1.0)),
+        (1, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.5833)),
+        (2, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.9167)),
+        (3, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(1.0)),
+        (1, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([0.5000, 1.0000, 0.3333])),
+        (2, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([1.0000, 1.0000, 0.6667])),
+        (3, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([1.0, 1.0, 1.0])),
     ],
 )
 def test_top_k(k: int, preds: Tensor, target: Tensor, average: str, expected_spec: Tensor):
     """A simple test to check that top_k works as expected."""
     class_metric = MulticlassSpecificity(top_k=k, average=average, num_classes=3)
     class_metric.update(preds, target)
 
-    assert torch.equal(class_metric.compute(), expected_spec)
-    assert torch.equal(multiclass_specificity(preds, target, top_k=k, average=average, num_classes=3), expected_spec)
+    assert torch.allclose(class_metric.compute(), expected_spec, atol=1e-4, rtol=1e-4)
+    assert torch.allclose(
+        multiclass_specificity(preds, target, top_k=k, average=average, num_classes=3),
+        expected_spec,
+        atol=1e-4,
+        rtol=1e-4,
+    )
 
 
 def _reference_specificity_multilabel_global(preds, target, ignore_index, average):
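
The specificity expectations can be derived the same way, since per-class specificity is TN / (TN + FP). With refined k=2 predictions [2, 1, 2, 0], only class 2 picks up a false positive (sample 1, whose target is 0), so its specificity drops to 2/3 while classes 0 and 1 stay at 1.0; the micro row is (2 + 3 + 2) / (2 + 3 + 2 + 1) = 0.875. A reproduction sketch (illustrative, assuming a build with this fix):

import torch
from torchmetrics.classification import MulticlassSpecificity

preds = torch.tensor([[0.1, 0.2, 0.7], [0.4, 0.4, 0.2], [0.3, 0.3, 0.4], [0.3, 0.3, 0.4]])
target = torch.tensor([0, 1, 2, 0])

# class 0: TN=2, FP=0 -> 1.0; class 1: TN=3, FP=0 -> 1.0; class 2: TN=2, FP=1 -> 0.6667
print(MulticlassSpecificity(num_classes=3, top_k=2, average="none")(preds, target))
# tensor([1.0000, 1.0000, 0.6667])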
