Skip to content

Commit e514d25

Browse files
committed
fix integration between wrapper
1 parent d6a1ad2 commit e514d25

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/torchmetrics/wrappers/classwise.py

+5
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ def __init__(
139139

140140
self._update_count = 1
141141

142+
@property
143+
def higher_is_better(self) -> Optional[bool]:
144+
"""Return if the metric is higher the better."""
145+
return self.metric.higher_is_better
146+
142147
def _filter_kwargs(self, **kwargs: Any) -> dict[str, Any]:
143148
"""Filter kwargs for the metric."""
144149
return self.metric._filter_kwargs(**kwargs)

src/torchmetrics/wrappers/tracker.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
2525
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
2626
from torchmetrics.utilities.prints import rank_zero_warn
27+
from torchmetrics.wrappers import ClasswiseWrapper
2728

2829
if not _MATPLOTLIB_AVAILABLE:
2930
__doctest_skip__ = ["MetricTracker.plot"]
@@ -269,7 +270,7 @@ def best_metric(
269270
return None, None
270271
return None
271272

272-
if isinstance(self._base_metric, Metric):
273+
if isinstance(self._base_metric, Metric) and not isinstance(self._base_metric, ClasswiseWrapper):
273274
fn = torch.max if self.maximize else torch.min
274275
try:
275276
value, idx = fn(res, 0)

0 commit comments

Comments
 (0)