16
16
17
17
import torch
18
18
from torch import Tensor
19
+ from typing_extensions import Literal
19
20
20
21
from torchmetrics .metric import Metric
21
22
from torchmetrics .utilities import rank_zero_warn
@@ -38,14 +39,15 @@ class BaseAggregator(Metric):
38
39
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
39
40
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
40
41
- ``'ignore'``: all `nan` values are silently removed
42
+ - ``'disable'``: disable all `nan` checks
41
43
- a float: if a float is provided will impute any `nan` values with this value
42
44
43
45
state_name: name of the metric state
44
46
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
45
47
46
48
Raises:
47
49
ValueError:
48
- If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
50
+ If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
49
51
50
52
"""
51
53
@@ -57,12 +59,12 @@ def __init__(
57
59
self ,
58
60
fn : Union [Callable , str ],
59
61
default_value : Union [Tensor , list ],
60
- nan_strategy : Union [str , float ] = "error" ,
62
+ nan_strategy : Union [Literal [ "error" , "warn" , "ignore" , "disable" ] , float ] = "error" ,
61
63
state_name : str = "value" ,
62
64
** kwargs : Any ,
63
65
) -> None :
64
66
super ().__init__ (** kwargs )
65
- allowed_nan_strategy = ("error" , "warn" , "ignore" )
67
+ allowed_nan_strategy = ("error" , "warn" , "ignore" , "disable" )
66
68
if nan_strategy not in allowed_nan_strategy and not isinstance (nan_strategy , float ):
67
69
raise ValueError (
68
70
f"Arg `nan_strategy` should either be a float or one of { allowed_nan_strategy } but got { nan_strategy } ."
@@ -81,26 +83,28 @@ def _cast_and_nan_check_input(
81
83
if weight is not None and not isinstance (weight , Tensor ):
82
84
weight = torch .as_tensor (weight , dtype = self .dtype , device = self .device )
83
85
84
- nans = torch .isnan (x )
85
- if weight is not None :
86
- nans_weight = torch .isnan (weight )
86
+ if self .nan_strategy != "disable" :
87
+ nans = torch .isnan (x )
88
+ if weight is not None :
89
+ nans_weight = torch .isnan (weight )
90
+ else :
91
+ nans_weight = torch .zeros_like (nans ).bool ()
92
+ weight = torch .ones_like (x )
93
+ if nans .any () or nans_weight .any ():
94
+ if self .nan_strategy == "error" :
95
+ raise RuntimeError ("Encountered `nan` values in tensor" )
96
+ if self .nan_strategy in ("ignore" , "warn" ):
97
+ if self .nan_strategy == "warn" :
98
+ rank_zero_warn ("Encountered `nan` values in tensor. Will be removed." , UserWarning )
99
+ x = x [~ (nans | nans_weight )]
100
+ weight = weight [~ (nans | nans_weight )]
101
+ else :
102
+ if not isinstance (self .nan_strategy , float ):
103
+ raise ValueError (f"`nan_strategy` shall be float but you pass { self .nan_strategy } " )
104
+ x [nans | nans_weight ] = self .nan_strategy
105
+ weight [nans | nans_weight ] = 1
87
106
else :
88
- nans_weight = torch .zeros_like (nans ).bool ()
89
107
weight = torch .ones_like (x )
90
- if nans .any () or nans_weight .any ():
91
- if self .nan_strategy == "error" :
92
- raise RuntimeError ("Encountered `nan` values in tensor" )
93
- if self .nan_strategy in ("ignore" , "warn" ):
94
- if self .nan_strategy == "warn" :
95
- rank_zero_warn ("Encountered `nan` values in tensor. Will be removed." , UserWarning )
96
- x = x [~ (nans | nans_weight )]
97
- weight = weight [~ (nans | nans_weight )]
98
- else :
99
- if not isinstance (self .nan_strategy , float ):
100
- raise ValueError (f"`nan_strategy` shall be float but you pass { self .nan_strategy } " )
101
- x [nans | nans_weight ] = self .nan_strategy
102
- weight [nans | nans_weight ] = self .nan_strategy
103
-
104
108
return x .to (self .dtype ), weight .to (self .dtype )
105
109
106
110
def update (self , value : Union [float , Tensor ]) -> None :
@@ -128,13 +132,14 @@ class MaxMetric(BaseAggregator):
128
132
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
129
133
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
130
134
- ``'ignore'``: all `nan` values are silently removed
135
+ - ``'disable'``: disable all `nan` checks
131
136
- a float: if a float is provided will impute any `nan` values with this value
132
137
133
138
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
134
139
135
140
Raises:
136
141
ValueError:
137
- If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
142
+ If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
138
143
139
144
Example:
140
145
>>> from torch import tensor
@@ -152,7 +157,7 @@ class MaxMetric(BaseAggregator):
152
157
153
158
def __init__ (
154
159
self ,
155
- nan_strategy : Union [str , float ] = "warn" ,
160
+ nan_strategy : Union [Literal [ "error" , "warn" , "ignore" , "disable" ] , float ] = "warn" ,
156
161
** kwargs : Any ,
157
162
) -> None :
158
163
super ().__init__ (
@@ -233,13 +238,14 @@ class MinMetric(BaseAggregator):
233
238
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
234
239
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
235
240
- ``'ignore'``: all `nan` values are silently removed
241
+ - ``'disable'``: disable all `nan` checks
236
242
- a float: if a float is provided will impute any `nan` values with this value
237
243
238
244
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
239
245
240
246
Raises:
241
247
ValueError:
242
- If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
248
+ If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
243
249
244
250
Example:
245
251
>>> from torch import tensor
@@ -257,7 +263,7 @@ class MinMetric(BaseAggregator):
257
263
258
264
def __init__ (
259
265
self ,
260
- nan_strategy : Union [str , float ] = "warn" ,
266
+ nan_strategy : Union [Literal [ "error" , "warn" , "ignore" , "disable" ] , float ] = "warn" ,
261
267
** kwargs : Any ,
262
268
) -> None :
263
269
super ().__init__ (
@@ -338,13 +344,14 @@ class SumMetric(BaseAggregator):
338
344
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
339
345
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
340
346
- ``'ignore'``: all `nan` values are silently removed
347
+ - ``'disable'``: disable all `nan` checks
341
348
- a float: if a float is provided will impute any `nan` values with this value
342
349
343
350
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
344
351
345
352
Raises:
346
353
ValueError:
347
- If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
354
+ If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
348
355
349
356
Example:
350
357
>>> from torch import tensor
@@ -361,7 +368,7 @@ class SumMetric(BaseAggregator):
361
368
362
369
def __init__ (
363
370
self ,
364
- nan_strategy : Union [str , float ] = "warn" ,
371
+ nan_strategy : Union [Literal [ "error" , "warn" , "ignore" , "disable" ] , float ] = "warn" ,
365
372
** kwargs : Any ,
366
373
) -> None :
367
374
super ().__init__ (
@@ -443,13 +450,14 @@ class CatMetric(BaseAggregator):
443
450
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
444
451
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
445
452
- ``'ignore'``: all `nan` values are silently removed
453
+ - ``'disable'``: disable all `nan` checks
446
454
- a float: if a float is provided will impute any `nan` values with this value
447
455
448
456
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
449
457
450
458
Raises:
451
459
ValueError:
452
- If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
460
+ If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
453
461
454
462
Example:
455
463
>>> from torch import tensor
@@ -466,7 +474,7 @@ class CatMetric(BaseAggregator):
466
474
467
475
def __init__ (
468
476
self ,
469
- nan_strategy : Union [str , float ] = "warn" ,
477
+ nan_strategy : Union [Literal [ "error" , "warn" , "ignore" , "disable" ] , float ] = "warn" ,
470
478
** kwargs : Any ,
471
479
) -> None :
472
480
super ().__init__ ("cat" , [], nan_strategy , ** kwargs )
@@ -505,17 +513,18 @@ class MeanMetric(BaseAggregator):
505
513
- ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated (weighted) mean over all inputs received
506
514
507
515
Args:
508
- nan_strategy: options:
516
+ nan_strategy: options:
509
517
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
510
518
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
511
519
- ``'ignore'``: all `nan` values are silently removed
520
+ - ``'disable'``: disable all `nan` checks
512
521
- a float: if a float is provided will impute any `nan` values with this value
513
522
514
523
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
515
524
516
525
Raises:
517
526
ValueError:
518
- If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
527
+ If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
519
528
520
529
Example:
521
530
>>> from torchmetrics.aggregation import MeanMetric
@@ -532,7 +541,7 @@ class MeanMetric(BaseAggregator):
532
541
533
542
def __init__ (
534
543
self ,
535
- nan_strategy : Union [str , float ] = "warn" ,
544
+ nan_strategy : Union [Literal [ "error" , "warn" , "ignore" , "disable" ] , float ] = "warn" ,
536
545
** kwargs : Any ,
537
546
) -> None :
538
547
super ().__init__ (
@@ -544,22 +553,24 @@ def __init__(
544
553
)
545
554
self .add_state ("weight" , default = torch .tensor (0.0 , dtype = torch .get_default_dtype ()), dist_reduce_fx = "sum" )
546
555
547
- def update (self , value : Union [float , Tensor ], weight : Union [float , Tensor ] = 1.0 ) -> None :
556
+ def update (self , value : Union [float , Tensor ], weight : Union [float , Tensor , None ] = None ) -> None :
548
557
"""Update state with data.
549
558
550
559
Args:
551
560
value: Either a float or tensor containing data. Additional tensor
552
561
dimensions will be flattened
553
562
weight: Either a float or tensor containing weights for calculating
554
563
the average. Shape of weight should be able to broadcast with
555
- the shape of `value`. Default to `1.0` corresponding to simple
564
+ the shape of `value`. Default to None corresponding to simple
556
565
harmonic average.
557
566
558
567
"""
559
568
# broadcast weight to value shape
560
569
if not isinstance (value , Tensor ):
561
570
value = torch .as_tensor (value , dtype = self .dtype , device = self .device )
562
- if weight is not None and not isinstance (weight , Tensor ):
571
+ if weight is None :
572
+ weight = torch .ones_like (value )
573
+ elif not isinstance (weight , Tensor ):
563
574
weight = torch .as_tensor (weight , dtype = self .dtype , device = self .device )
564
575
weight = torch .broadcast_to (weight , value .shape )
565
576
value , weight = self ._cast_and_nan_check_input (value , weight )
@@ -631,18 +642,18 @@ class RunningMean(Running):
631
642
- ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated sum over all inputs received
632
643
633
644
Args:
634
- window: The size of the running window.
635
645
nan_strategy: options:
636
646
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
637
647
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
638
648
- ``'ignore'``: all `nan` values are silently removed
649
+ - ``'disable'``: disable all `nan` checks
639
650
- a float: if a float is provided will impute any `nan` values with this value
640
651
641
652
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
642
653
643
654
Raises:
644
655
ValueError:
645
- If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
656
+ If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
646
657
647
658
Example:
648
659
>>> from torch import tensor
@@ -665,7 +676,7 @@ class RunningMean(Running):
665
676
def __init__ (
666
677
self ,
667
678
window : int = 5 ,
668
- nan_strategy : Union [str , float ] = "warn" ,
679
+ nan_strategy : Union [Literal [ "error" , "warn" , "ignore" , "disable" ] , float ] = "warn" ,
669
680
** kwargs : Any ,
670
681
) -> None :
671
682
super ().__init__ (base_metric = MeanMetric (nan_strategy = nan_strategy , ** kwargs ), window = window )
@@ -693,13 +704,14 @@ class RunningSum(Running):
693
704
- ``'error'``: if any `nan` values are encountered will give a RuntimeError
694
705
- ``'warn'``: if any `nan` values are encountered will give a warning and continue
695
706
- ``'ignore'``: all `nan` values are silently removed
707
+ - ``'disable'``: disable all `nan` checks
696
708
- a float: if a float is provided will impute any `nan` values with this value
697
709
698
710
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
699
711
700
712
Raises:
701
713
ValueError:
702
- If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore`` or a float
714
+ If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
703
715
704
716
Example:
705
717
>>> from torch import tensor
@@ -722,7 +734,7 @@ class RunningSum(Running):
722
734
def __init__ (
723
735
self ,
724
736
window : int = 5 ,
725
- nan_strategy : Union [str , float ] = "warn" ,
737
+ nan_strategy : Union [Literal [ "error" , "warn" , "ignore" , "disable" ] , float ] = "warn" ,
726
738
** kwargs : Any ,
727
739
) -> None :
728
740
super ().__init__ (base_metric = SumMetric (nan_strategy = nan_strategy , ** kwargs ), window = window )
0 commit comments