1
1
from __future__ import annotations
2
2
3
+ from typing import Optional
4
+
3
5
import torch
4
6
from torch import Tensor , nn
5
7
6
- from kornia .core .check import KORNIA_CHECK_SHAPE
8
+ from kornia .core .check import KORNIA_CHECK , KORNIA_CHECK_IS_TENSOR , KORNIA_CHECK_SHAPE
7
9
8
10
# based on:
9
11
# https://github.com/bermanmaxim/LovaszSoftmax
10
12
11
13
12
- def lovasz_softmax_loss (pred : Tensor , target : Tensor ) -> Tensor :
14
+ def lovasz_softmax_loss (pred : Tensor , target : Tensor , weight : Optional [ Tensor ] = None ) -> Tensor :
13
15
r"""Criterion that computes a surrogate multi-class intersection-over-union (IoU) loss.
14
16
15
17
According to [1], we compute the IoU as follows:
@@ -22,7 +24,7 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
22
24
23
25
Where:
24
26
- :math:`X` expects to be the scores of each class.
25
- - :math:`Y` expects to be the binary tensor with the class labels.
27
+ - :math:`Y` expects to be the long tensor with the class labels.
26
28
27
29
the loss, is finally computed as:
28
30
@@ -41,6 +43,7 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
41
43
pred: logits tensor with shape :math:`(N, C, H, W)` where C = number of classes > 1.
42
44
target: labels tensor with shape :math:`(N, H, W)` where each value
43
45
is :math:`0 ≤ target[i] ≤ C-1`.
46
+ weight: weights for classes with shape :math:`(num\_of\_classes,)`.
44
47
45
48
Return:
46
49
a scalar with the computed loss.
@@ -65,6 +68,19 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
65
68
if not pred .device == target .device :
66
69
raise ValueError (f"pred and target must be in the same device. Got: { pred .device } and { target .device } " )
67
70
71
+ num_of_classes = pred .shape [1 ]
72
+ # validate the optional per-class weight tensor
73
+ if weight is not None :
74
+ KORNIA_CHECK_IS_TENSOR (weight , "weight must be Tensor or None." )
75
+ KORNIA_CHECK (
76
+ (weight .shape [0 ] == num_of_classes and weight .numel () == num_of_classes ),
77
+ f"weight shape must be (num_of_classes,): ({ num_of_classes } ,), got { weight .shape } " ,
78
+ )
79
+ KORNIA_CHECK (
80
+ weight .device == pred .device ,
81
+ f"weight and pred must be in the same device. Got: { weight .device } and { pred .device } " ,
82
+ )
83
+
68
84
# flatten pred [B, C, -1] and target [B, -1] and to float
69
85
pred_flatten : Tensor = pred .reshape (pred .shape [0 ], pred .shape [1 ], - 1 )
70
86
target_flatten : Tensor = target .reshape (target .shape [0 ], - 1 ).float ()
@@ -91,7 +107,7 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
91
107
gradient : Tensor = 1.0 - intersection / union
92
108
if N > 1 :
93
109
gradient [..., 1 :] = gradient [..., 1 :] - gradient [..., :- 1 ]
94
- loss : Tensor = (errors_sorted .relu () * gradient ).sum (1 ).mean ()
110
+ loss : Tensor = (errors_sorted .relu () * gradient ).sum (1 ).mean () * ( 1.0 if weight is None else weight [ c ])
95
111
losses .append (loss )
96
112
final_loss : Tensor = torch .stack (losses , dim = 0 ).mean ()
97
113
return final_loss
@@ -129,6 +145,7 @@ class LovaszSoftmaxLoss(nn.Module):
129
145
pred: logits tensor with shape :math:`(N, C, H, W)` where C = number of classes > 1.
130
146
target: labels tensor with shape :math:`(N, H, W)` where each value
131
147
is :math:`0 ≤ target[i] ≤ C-1`.
148
+ weight: weights for classes with shape :math:`(num\_of\_classes,)`.
132
149
133
150
Return:
134
151
a scalar with the computed loss.
@@ -142,8 +159,9 @@ class LovaszSoftmaxLoss(nn.Module):
142
159
>>> output.backward()
143
160
"""
144
161
145
- def __init__ (self ) -> None :
162
+ def __init__ (self , weight : Optional [ Tensor ] = None ) -> None :
146
163
super ().__init__ ()
164
+ self .weight = weight
147
165
148
166
def forward (self , pred : Tensor , target : Tensor ) -> Tensor :
149
- return lovasz_softmax_loss (pred = pred , target = target )
167
+ return lovasz_softmax_loss (pred = pred , target = target , weight = self . weight )
0 commit comments