Skip to content

Commit 28c75ea

Browse files
authored
Merge branch 'master' into newmetric/dists
2 parents a7cc5d9 + 0423a2c commit 28c75ea

File tree

8 files changed

+606
-392
lines changed

8 files changed

+606
-392
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3939

4040
### Fixed
4141

42-
-
42+
- Fixed integration between classwise wrapper and metric tracker ([#3004](https://github.com/PyTorchLightning/metrics/pull/3004))
4343

4444

4545
---

src/torchmetrics/detection/helpers.py

+468-1
Large diffs are not rendered by default.

src/torchmetrics/detection/mean_ap.py

+120-385
Large diffs are not rendered by default.

src/torchmetrics/utilities/backends.py

Whitespace-only changes.

src/torchmetrics/wrappers/classwise.py

+5
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ def __init__(
139139

140140
self._update_count = 1
141141

142+
@property
143+
def higher_is_better(self) -> Optional[bool]: # type: ignore
144+
"""Return if the metric is higher the better."""
145+
return self.metric.higher_is_better
146+
142147
def _filter_kwargs(self, **kwargs: Any) -> dict[str, Any]:
143148
"""Filter kwargs for the metric."""
144149
return self.metric._filter_kwargs(**kwargs)

src/torchmetrics/wrappers/tracker.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
2525
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
2626
from torchmetrics.utilities.prints import rank_zero_warn
27+
from torchmetrics.wrappers import ClasswiseWrapper
2728

2829
if not _MATPLOTLIB_AVAILABLE:
2930
__doctest_skip__ = ["MetricTracker.plot"]
@@ -269,7 +270,7 @@ def best_metric(
269270
return None, None
270271
return None
271272

272-
if isinstance(self._base_metric, Metric):
273+
if isinstance(self._base_metric, Metric) and not isinstance(self._base_metric, ClasswiseWrapper):
273274
fn = torch.max if self.maximize else torch.min
274275
try:
275276
value, idx = fn(res, 0)

tests/unittests/detection/test_map.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _generate_coco_inputs(iou_type):
4848
and should therefore correspond directly to the result on the webpage
4949
5050
"""
51-
batched_preds, batched_target = MeanAveragePrecision.coco_to_tm(
51+
batched_preds, batched_target = MeanAveragePrecision().coco_to_tm(
5252
_DETECTION_BBOX if iou_type == "bbox" else _DETECTION_SEGM, _DETECTION_VAL, iou_type
5353
)
5454

@@ -74,7 +74,7 @@ def test_tm_to_coco(tmpdir, iou_type, backend):
7474
for bp, bt in zip(preds, target):
7575
metric.update(bp, bt)
7676
metric.tm_to_coco(f"{tmpdir}/tm_map_input")
77-
preds_2, target_2 = MeanAveragePrecision.coco_to_tm(
77+
preds_2, target_2 = MeanAveragePrecision().coco_to_tm(
7878
f"{tmpdir}/tm_map_input_preds.json",
7979
f"{tmpdir}/tm_map_input_target.json",
8080
iou_type=iou_type,
@@ -246,7 +246,7 @@ def test_compare_both_same_time(tmpdir, backend):
246246
combined = [{**box, **seg} for box, seg in zip(boxes, segmentations)]
247247
with open(f"{tmpdir}/combined.json", "w") as f:
248248
json.dump(combined, f)
249-
batched_preds, batched_target = MeanAveragePrecision.coco_to_tm(
249+
batched_preds, batched_target = MeanAveragePrecision().coco_to_tm(
250250
f"{tmpdir}/combined.json", _DETECTION_VAL, iou_type=["bbox", "segm"]
251251
)
252252
batched_preds = [batched_preds[10 * i : 10 * (i + 1)] for i in range(10)]

tests/unittests/wrappers/test_tracker.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
2828
from torchmetrics.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_6
29-
from torchmetrics.wrappers import MetricTracker, MultioutputWrapper
29+
from torchmetrics.wrappers import ClasswiseWrapper, MetricTracker, MultioutputWrapper
3030
from unittests._helpers import seed_all
3131

3232
seed_all(42)
@@ -101,6 +101,11 @@ def test_raises_error_if_increment_not_called(method, method_input):
101101
(torch.randn(50), torch.randn(50)),
102102
[False, False],
103103
),
104+
(
105+
ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)),
106+
(torch.randint(3, (50,)), torch.randint(3, (50,))),
107+
True,
108+
),
104109
],
105110
)
106111
def test_tracker(base_metric, metric_input, maximize):
@@ -244,6 +249,7 @@ def test_tracker_futurewarning():
244249
MeanAbsoluteError(),
245250
MulticlassAccuracy(num_classes=10),
246251
MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
252+
ClasswiseWrapper(MulticlassAccuracy(num_classes=10, average=None)),
247253
],
248254
)
249255
def test_tracker_higher_is_better_integration(base_metric):

0 commit comments

Comments
 (0)