Skip to content

Commit 9b90288

Browse files
committed
add tests
1 parent e514d25 commit 9b90288

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

tests/unittests/wrappers/test_tracker.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
2828
from torchmetrics.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_6
29-
from torchmetrics.wrappers import MetricTracker, MultioutputWrapper
29+
from torchmetrics.wrappers import ClasswiseWrapper, MetricTracker, MultioutputWrapper
3030
from unittests._helpers import seed_all
3131

3232
seed_all(42)
@@ -101,6 +101,11 @@ def test_raises_error_if_increment_not_called(method, method_input):
101101
(torch.randn(50), torch.randn(50)),
102102
[False, False],
103103
),
104+
(
105+
ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)),
106+
(torch.randint(3, (50,)), torch.randint(3, (50,))),
107+
True,
108+
),
104109
],
105110
)
106111
def test_tracker(base_metric, metric_input, maximize):
@@ -244,6 +249,7 @@ def test_tracker_futurewarning():
244249
MeanAbsoluteError(),
245250
MulticlassAccuracy(num_classes=10),
246251
MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
252+
ClasswiseWrapper(MulticlassAccuracy(num_classes=10, average=None)),
247253
],
248254
)
249255
def test_tracker_higher_is_better_integration(base_metric):

0 commit comments

Comments
 (0)