Skip to content

Commit 656191a

Browse files
czmrandcrand-mbeBordaSkafteNickimergify[bot]
authored
torchmetric optimizations (#2943)
* add nan_strategy "disable" to disable nan checks set default in update to None to prevent sync event see https://medium.com/@chaimrand/efficient-metric-collection-in-pytorch-avoiding-the-performance-pitfalls-of-torchmetrics-0dea81413681 for motivation --------- Co-authored-by: Chaim Rand <chaim.rand@mobileye.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Jirka B <j.borovec+github@gmail.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent a9505cb commit 656191a

File tree

3 files changed

+62
-42
lines changed

3 files changed

+62
-42
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212

1313
### Added
1414

15-
-
15+
- Added `disable` option to `nan_strategy` in basic aggregation metrics ([#2943](https://github.com/PyTorchLightning/metrics/pull/2943))
1616

1717

1818
### Changed

src/torchmetrics/aggregation.py

+52-40
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import torch
1818
from torch import Tensor
19+
from typing_extensions import Literal
1920

2021
from torchmetrics.metric import Metric
2122
from torchmetrics.utilities import rank_zero_warn
@@ -38,14 +39,15 @@ class BaseAggregator(Metric):
3839
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
3940
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
4041
- ``'ignore'``: all `nan` values are silently removed
42+
- ``'disable'``: disable all `nan` checks
4143
- a float: if a float is provided will impute any `nan` values with this value
4244
4345
state_name: name of the metric state
4446
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
4547
4648
Raises:
4749
ValueError:
48-
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
50+
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
4951
5052
"""
5153

@@ -57,12 +59,12 @@ def __init__(
5759
self,
5860
fn: Union[Callable, str],
5961
default_value: Union[Tensor, list],
60-
nan_strategy: Union[str, float] = "error",
62+
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "error",
6163
state_name: str = "value",
6264
**kwargs: Any,
6365
) -> None:
6466
super().__init__(**kwargs)
65-
allowed_nan_strategy = ("error", "warn", "ignore")
67+
allowed_nan_strategy = ("error", "warn", "ignore", "disable")
6668
if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
6769
raise ValueError(
6870
f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy} but got {nan_strategy}."
@@ -81,26 +83,28 @@ def _cast_and_nan_check_input(
8183
if weight is not None and not isinstance(weight, Tensor):
8284
weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)
8385

84-
nans = torch.isnan(x)
85-
if weight is not None:
86-
nans_weight = torch.isnan(weight)
86+
if self.nan_strategy != "disable":
87+
nans = torch.isnan(x)
88+
if weight is not None:
89+
nans_weight = torch.isnan(weight)
90+
else:
91+
nans_weight = torch.zeros_like(nans).bool()
92+
weight = torch.ones_like(x)
93+
if nans.any() or nans_weight.any():
94+
if self.nan_strategy == "error":
95+
raise RuntimeError("Encountered `nan` values in tensor")
96+
if self.nan_strategy in ("ignore", "warn"):
97+
if self.nan_strategy == "warn":
98+
rank_zero_warn("Encountered `nan` values in tensor. Will be removed.", UserWarning)
99+
x = x[~(nans | nans_weight)]
100+
weight = weight[~(nans | nans_weight)]
101+
else:
102+
if not isinstance(self.nan_strategy, float):
103+
raise ValueError(f"`nan_strategy` shall be float but you pass {self.nan_strategy}")
104+
x[nans | nans_weight] = self.nan_strategy
105+
weight[nans | nans_weight] = 1
87106
else:
88-
nans_weight = torch.zeros_like(nans).bool()
89107
weight = torch.ones_like(x)
90-
if nans.any() or nans_weight.any():
91-
if self.nan_strategy == "error":
92-
raise RuntimeError("Encountered `nan` values in tensor")
93-
if self.nan_strategy in ("ignore", "warn"):
94-
if self.nan_strategy == "warn":
95-
rank_zero_warn("Encountered `nan` values in tensor. Will be removed.", UserWarning)
96-
x = x[~(nans | nans_weight)]
97-
weight = weight[~(nans | nans_weight)]
98-
else:
99-
if not isinstance(self.nan_strategy, float):
100-
raise ValueError(f"`nan_strategy` shall be float but you pass {self.nan_strategy}")
101-
x[nans | nans_weight] = self.nan_strategy
102-
weight[nans | nans_weight] = self.nan_strategy
103-
104108
return x.to(self.dtype), weight.to(self.dtype)
105109

106110
def update(self, value: Union[float, Tensor]) -> None:
@@ -128,13 +132,14 @@ class MaxMetric(BaseAggregator):
128132
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
129133
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
130134
- ``'ignore'``: all `nan` values are silently removed
135+
- ``'disable'``: disable all `nan` checks
131136
- a float: if a float is provided will impute any `nan` values with this value
132137
133138
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
134139
135140
Raises:
136141
ValueError:
137-
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
142+
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
138143
139144
Example:
140145
>>> from torch import tensor
@@ -152,7 +157,7 @@ class MaxMetric(BaseAggregator):
152157

153158
def __init__(
154159
self,
155-
nan_strategy: Union[str, float] = "warn",
160+
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
156161
**kwargs: Any,
157162
) -> None:
158163
super().__init__(
@@ -233,13 +238,14 @@ class MinMetric(BaseAggregator):
233238
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
234239
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
235240
- ``'ignore'``: all `nan` values are silently removed
241+
- ``'disable'``: disable all `nan` checks
236242
- a float: if a float is provided will impute any `nan` values with this value
237243
238244
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
239245
240246
Raises:
241247
ValueError:
242-
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
248+
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
243249
244250
Example:
245251
>>> from torch import tensor
@@ -257,7 +263,7 @@ class MinMetric(BaseAggregator):
257263

258264
def __init__(
259265
self,
260-
nan_strategy: Union[str, float] = "warn",
266+
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
261267
**kwargs: Any,
262268
) -> None:
263269
super().__init__(
@@ -338,13 +344,14 @@ class SumMetric(BaseAggregator):
338344
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
339345
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
340346
- ``'ignore'``: all `nan` values are silently removed
347+
- ``'disable'``: disable all `nan` checks
341348
- a float: if a float is provided will impute any `nan` values with this value
342349
343350
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
344351
345352
Raises:
346353
ValueError:
347-
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
354+
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
348355
349356
Example:
350357
>>> from torch import tensor
@@ -361,7 +368,7 @@ class SumMetric(BaseAggregator):
361368

362369
def __init__(
363370
self,
364-
nan_strategy: Union[str, float] = "warn",
371+
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
365372
**kwargs: Any,
366373
) -> None:
367374
super().__init__(
@@ -443,13 +450,14 @@ class CatMetric(BaseAggregator):
443450
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
444451
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
445452
- ``'ignore'``: all `nan` values are silently removed
453+
- ``'disable'``: disable all `nan` checks
446454
- a float: if a float is provided will impute any `nan` values with this value
447455
448456
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
449457
450458
Raises:
451459
ValueError:
452-
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
460+
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
453461
454462
Example:
455463
>>> from torch import tensor
@@ -466,7 +474,7 @@ class CatMetric(BaseAggregator):
466474

467475
def __init__(
468476
self,
469-
nan_strategy: Union[str, float] = "warn",
477+
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
470478
**kwargs: Any,
471479
) -> None:
472480
super().__init__("cat", [], nan_strategy, **kwargs)
@@ -505,17 +513,18 @@ class MeanMetric(BaseAggregator):
505513
- ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated (weighted) mean over all inputs received
506514
507515
Args:
508-
nan_strategy: options:
516+
nan_strategy: options:
509517
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
510518
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
511519
- ``'ignore'``: all `nan` values are silently removed
520+
- ``'disable'``: disable all `nan` checks
512521
- a float: if a float is provided will impute any `nan` values with this value
513522
514523
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
515524
516525
Raises:
517526
ValueError:
518-
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
527+
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
519528
520529
Example:
521530
>>> from torchmetrics.aggregation import MeanMetric
@@ -532,7 +541,7 @@ class MeanMetric(BaseAggregator):
532541

533542
def __init__(
534543
self,
535-
nan_strategy: Union[str, float] = "warn",
544+
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
536545
**kwargs: Any,
537546
) -> None:
538547
super().__init__(
@@ -544,22 +553,24 @@ def __init__(
544553
)
545554
self.add_state("weight", default=torch.tensor(0.0, dtype=torch.get_default_dtype()), dist_reduce_fx="sum")
546555

547-
def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0) -> None:
556+
def update(self, value: Union[float, Tensor], weight: Union[float, Tensor, None] = None) -> None:
548557
"""Update state with data.
549558
550559
Args:
551560
value: Either a float or tensor containing data. Additional tensor
552561
dimensions will be flattened
553562
weight: Either a float or tensor containing weights for calculating
554563
the average. Shape of weight should be able to broadcast with
555-
the shape of `value`. Default to `1.0` corresponding to simple
564+
the shape of `value`. Default to None corresponding to simple
556565
harmonic average.
557566
558567
"""
559568
# broadcast weight to value shape
560569
if not isinstance(value, Tensor):
561570
value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
562-
if weight is not None and not isinstance(weight, Tensor):
571+
if weight is None:
572+
weight = torch.ones_like(value)
573+
elif not isinstance(weight, Tensor):
563574
weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)
564575
weight = torch.broadcast_to(weight, value.shape)
565576
value, weight = self._cast_and_nan_check_input(value, weight)
@@ -631,18 +642,18 @@ class RunningMean(Running):
631642
- ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated sum over all inputs received
632643
633644
Args:
634-
window: The size of the running window.
635645
nan_strategy: options:
636646
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
637647
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
638648
- ``'ignore'``: all `nan` values are silently removed
649+
- ``'disable'``: disable all `nan` checks
639650
- a float: if a float is provided will impute any `nan` values with this value
640651
641652
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
642653
643654
Raises:
644655
ValueError:
645-
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
656+
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
646657
647658
Example:
648659
>>> from torch import tensor
@@ -665,7 +676,7 @@ class RunningMean(Running):
665676
def __init__(
666677
self,
667678
window: int = 5,
668-
nan_strategy: Union[str, float] = "warn",
679+
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
669680
**kwargs: Any,
670681
) -> None:
671682
super().__init__(base_metric=MeanMetric(nan_strategy=nan_strategy, **kwargs), window=window)
@@ -693,13 +704,14 @@ class RunningSum(Running):
693704
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
694705
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
695706
- ``'ignore'``: all `nan` values are silently removed
707+
- ``'disable'``: disable all `nan` checks
696708
- a float: if a float is provided will impute any `nan` values with this value
697709
698710
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
699711
700712
Raises:
701713
ValueError:
702-
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
714+
If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
703715
704716
Example:
705717
>>> from torch import tensor
@@ -722,7 +734,7 @@ class RunningSum(Running):
722734
def __init__(
723735
self,
724736
window: int = 5,
725-
nan_strategy: Union[str, float] = "warn",
737+
nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
726738
**kwargs: Any,
727739
) -> None:
728740
super().__init__(base_metric=SumMetric(nan_strategy=nan_strategy, **kwargs), window=window)

tests/unittests/bases/test_aggregation.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
import torch
44

5+
from torchmetrics import Metric
56
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric
67
from torchmetrics.collections import MetricCollection
78
from unittests import BATCH_SIZE, NUM_BATCHES
@@ -121,28 +122,35 @@ def test_nan_error(value, nan_strategy, metric_class):
121122
(MinMetric, 2.0, _CASE_1, 2.0),
122123
(MinMetric, "ignore", _CASE_2, 1.0),
123124
(MinMetric, 2.0, _CASE_2, 1.0),
125+
(MinMetric, "disable", _CASE_1, torch.tensor(float("nan"))),
124126
(MaxMetric, "ignore", _CASE_1, -torch.tensor(float("inf"))),
125127
(MaxMetric, 2.0, _CASE_1, 2.0),
126128
(MaxMetric, "ignore", _CASE_2, 5.0),
127129
(MaxMetric, 2.0, _CASE_2, 5.0),
130+
(MaxMetric, "disable", _CASE_1, torch.tensor(float("nan"))),
128131
(SumMetric, "ignore", _CASE_1, 0.0),
129132
(SumMetric, 2.0, _CASE_1, 10.0),
130133
(SumMetric, "ignore", _CASE_2, 12.0),
131134
(SumMetric, 2.0, _CASE_2, 14.0),
135+
(SumMetric, "disable", _CASE_1, torch.tensor(float("nan"))),
136+
(SumMetric, "disable", _CASE_2, torch.tensor(float("nan"))),
132137
(MeanMetric, "ignore", _CASE_1, torch.tensor([float("nan")])),
133138
(MeanMetric, 2.0, _CASE_1, 2.0),
134139
(MeanMetric, "ignore", _CASE_2, 3.0),
135140
(MeanMetric, 2.0, _CASE_2, 2.8),
141+
(MeanMetric, "disable", _CASE_1, torch.tensor(float("nan"))),
142+
(MeanMetric, "disable", _CASE_2, torch.tensor(float("nan"))),
136143
(CatMetric, "ignore", _CASE_1, []),
137144
(CatMetric, 2.0, _CASE_1, torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0])),
138145
(CatMetric, "ignore", _CASE_2, torch.tensor([1.0, 2.0, 4.0, 5.0])),
139146
(CatMetric, 2.0, _CASE_2, torch.tensor([1.0, 2.0, 2.0, 4.0, 5.0])),
140147
(CatMetric, "ignore", torch.zeros(5), torch.zeros(5)),
148+
(CatMetric, "disable", _CASE_1, _CASE_1),
141149
],
142150
)
143151
def test_nan_expected(metric_class, nan_strategy, value, expected):
144152
"""Test that nan values are handled correctly."""
145-
metric = metric_class(nan_strategy=nan_strategy)
153+
metric: Metric = metric_class(nan_strategy=nan_strategy)
146154
metric.update(value.clone())
147155
out = metric.compute()
148156
assert np.allclose(out, expected, equal_nan=True)

0 commit comments

Comments
 (0)