Skip to content

Commit a9505cb

Browse files
rbedyakinpre-commit-ci[bot]BordaSkafteNickimergify[bot]
authored
fix for MetricCollection.update gives identical results (#2944)
* fix issue 2916 * chlog * linter --------- Co-authored-by: Roman <Roman> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Jirka B <j.borovec+github@gmail.com> Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 95e3bfa commit a9505cb

File tree

3 files changed

+44
-5
lines changed

3 files changed

+44
-5
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2727

2828
### Fixed
2929

30-
-
30+
- Fixed `MetricCollection.update` gives identical results ([#2944](https://github.com/Lightning-AI/torchmetrics/issues/2944))
3131

3232

3333
---

src/torchmetrics/collections.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,19 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
314314
if type(state1) != type(state2): # noqa: E721
315315
return False
316316

317-
if isinstance(state1, Tensor) and isinstance(state2, Tensor):
318-
return state1.shape == state2.shape and allclose(state1, state2)
317+
if (
318+
isinstance(state1, Tensor)
319+
and isinstance(state2, Tensor)
320+
and not (state1.shape == state2.shape and allclose(state1, state2))
321+
):
322+
return False
319323

320-
if isinstance(state1, list) and isinstance(state2, list):
321-
return all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2))
324+
if (
325+
isinstance(state1, list)
326+
and isinstance(state2, list)
327+
and not (all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2)))
328+
):
329+
return False
322330

323331
return True
324332

tests/unittests/bases/test_collections.py

+31
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
MultilabelAUROC,
3434
MultilabelAveragePrecision,
3535
)
36+
from torchmetrics.text import BLEUScore
3637
from torchmetrics.utilities.checks import _allclose_recursive
3738
from unittests._helpers import seed_all
3839
from unittests._helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum
@@ -743,3 +744,33 @@ def compute(self):
743744
# Print the calculated metrics
744745
assert "my_prefix/accuracy/my_postfix" in res
745746
assert "my_prefix/precision/my_postfix" in res
747+
748+
749+
def test_collection_update():
750+
"""Test that metric collection updates metrics.
751+
752+
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2916
753+
754+
"""
755+
metrics = MetricCollection({
756+
"bleu-1": BLEUScore(1),
757+
"bleu-2": BLEUScore(2),
758+
"bleu-3": BLEUScore(3),
759+
"bleu-4": BLEUScore(4),
760+
})
761+
762+
preds = ["the cat is on the mat"]
763+
target = [["there is a cat on the mat", "a cat is on the mat"]]
764+
765+
metrics.update(preds, target)
766+
actual = metrics.compute()
767+
768+
expected = {
769+
"bleu-1": torch.tensor(0.8333),
770+
"bleu-2": torch.tensor(0.8165),
771+
"bleu-3": torch.tensor(0.7937),
772+
"bleu-4": torch.tensor(0.7598),
773+
}
774+
775+
for k, v in expected.items():
776+
torch.testing.assert_close(actual=actual.get(k), expected=v, rtol=1e-4, atol=1e-4)

0 commit comments

Comments
 (0)