Skip to content

Commit 0423a2c

Browse files
authored
Fix integration between classwise wrapper and metric tracker (#3004)
* fix integration between wrapper * add tests * changelog * fix typing
1 parent 56f39dd commit 0423a2c

File tree

4 files changed

+15
-3
lines changed

4 files changed

+15
-3
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3636

3737
### Fixed
3838

39-
-
39+
- Fixed integration between classwise wrapper and metric tracker ([#3004](https://github.com/PyTorchLightning/metrics/pull/3004))
4040

4141

4242
---

src/torchmetrics/wrappers/classwise.py

+5
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ def __init__(
139139

140140
self._update_count = 1
141141

142+
@property
143+
def higher_is_better(self) -> Optional[bool]: # type: ignore
144+
"""Return if the metric is higher the better."""
145+
return self.metric.higher_is_better
146+
142147
def _filter_kwargs(self, **kwargs: Any) -> dict[str, Any]:
143148
"""Filter kwargs for the metric."""
144149
return self.metric._filter_kwargs(**kwargs)

src/torchmetrics/wrappers/tracker.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
2525
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
2626
from torchmetrics.utilities.prints import rank_zero_warn
27+
from torchmetrics.wrappers import ClasswiseWrapper
2728

2829
if not _MATPLOTLIB_AVAILABLE:
2930
__doctest_skip__ = ["MetricTracker.plot"]
@@ -269,7 +270,7 @@ def best_metric(
269270
return None, None
270271
return None
271272

272-
if isinstance(self._base_metric, Metric):
273+
if isinstance(self._base_metric, Metric) and not isinstance(self._base_metric, ClasswiseWrapper):
273274
fn = torch.max if self.maximize else torch.min
274275
try:
275276
value, idx = fn(res, 0)

tests/unittests/wrappers/test_tracker.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
2828
from torchmetrics.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_6
29-
from torchmetrics.wrappers import MetricTracker, MultioutputWrapper
29+
from torchmetrics.wrappers import ClasswiseWrapper, MetricTracker, MultioutputWrapper
3030
from unittests._helpers import seed_all
3131

3232
seed_all(42)
@@ -101,6 +101,11 @@ def test_raises_error_if_increment_not_called(method, method_input):
101101
(torch.randn(50), torch.randn(50)),
102102
[False, False],
103103
),
104+
(
105+
ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)),
106+
(torch.randint(3, (50,)), torch.randint(3, (50,))),
107+
True,
108+
),
104109
],
105110
)
106111
def test_tracker(base_metric, metric_input, maximize):
@@ -244,6 +249,7 @@ def test_tracker_futurewarning():
244249
MeanAbsoluteError(),
245250
MulticlassAccuracy(num_classes=10),
246251
MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
252+
ClasswiseWrapper(MulticlassAccuracy(num_classes=10, average=None)),
247253
],
248254
)
249255
def test_tracker_higher_is_better_integration(base_metric):

0 commit comments

Comments
 (0)