Skip to content

Commit cd24d2b

Browse files
nkaenzigpre-commit-ci[bot]SkafteNicki
authored
Fix issue with shared state of MetricCollection compute group when using DiceScore(average="weighted") (#2848)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
1 parent a968ebe commit cd24d2b

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2727

2828
### Fixed
2929

30-
-
30+
- Fixed issue with shared state in metric collection when using dice score ([#2848](https://github.com/PyTorchLightning/metrics/pull/2848))
3131

3232

3333
---

src/torchmetrics/segmentation/dice.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
131131
)
132132
self.numerator.append(numerator)
133133
self.denominator.append(denominator)
134-
if self.average == "weighted":
135-
self.support.append(support)
134+
self.support.append(support)
136135

137136
def compute(self) -> Tensor:
138137
"""Computes the Dice Score."""

tests/unittests/segmentation/test_dice.py

+30
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pytest
1717
import torch
1818
from sklearn.metrics import f1_score
19+
from torchmetrics import MetricCollection
1920
from torchmetrics.functional.segmentation.dice import dice_score
2021
from torchmetrics.segmentation.dice import DiceScore
2122

@@ -106,3 +107,32 @@ def test_dice_score_functional(self, preds, target, input_format, include_backgr
106107
"input_format": input_format,
107108
},
108109
)
110+
111+
112+
@pytest.mark.parametrize("compute_groups", [True, False])
113+
def test_dice_score_metric_collection(compute_groups: bool, num_batches: int = 4):
114+
"""Test that the metric works within a metric collection with and without compute groups."""
115+
metric_collection = MetricCollection(
116+
metrics={
117+
"DiceScore (micro)": DiceScore(
118+
num_classes=NUM_CLASSES,
119+
average="micro",
120+
),
121+
"DiceScore (macro)": DiceScore(
122+
num_classes=NUM_CLASSES,
123+
average="macro",
124+
),
125+
"DiceScore (weighted)": DiceScore(
126+
num_classes=NUM_CLASSES,
127+
average="weighted",
128+
),
129+
},
130+
compute_groups=compute_groups,
131+
)
132+
133+
for _ in range(num_batches):
134+
metric_collection.update(_inputs1.preds, _inputs1.target)
135+
result = metric_collection.compute()
136+
137+
assert isinstance(result, dict)
138+
assert len(set(metric_collection.keys()) - set(result.keys())) == 0

0 commit comments

Comments
 (0)