@@ -83,7 +83,7 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor, weight: Optional[Tensor] =
83
83
84
84
# flatten pred [B, C, -1] and target [B, -1] and to float
85
85
pred_flatten : Tensor = pred .reshape (pred .shape [0 ], pred .shape [1 ], - 1 )
86
- target_flatten : Tensor = target .reshape (target .shape [0 ], - 1 ). float ()
86
+ target_flatten : Tensor = target .reshape (target .shape [0 ], - 1 )
87
87
88
88
# get shapes
89
89
B , C , N = pred_flatten .shape
@@ -92,24 +92,24 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor, weight: Optional[Tensor] =
92
92
pred_soft : Tensor = pred_flatten .softmax (1 )
93
93
94
94
# compute actual loss
95
- losses : list [ Tensor ] = []
96
- batch_index : Tensor = torch .arange ( B , device = pred . device ). reshape ( - 1 , 1 ). repeat ( 1 , N ). reshape ( - 1 )
97
- for c in range ( C ):
98
- foreground : Tensor = 1.0 * ( target_flatten == c )
99
- class_pred : Tensor = pred_soft [:, c ]
100
- errors = ( class_pred - foreground ). abs ( )
101
- errors_sorted , permutation = torch . sort ( errors , dim = 1 , descending = True )
102
- target_sorted : Tensor = target_flatten [ batch_index , permutation . view ( - 1 )]
103
- target_sorted = target_sorted . view ( B , N )
104
- target_sorted_sum : Tensor = target_sorted . sum ( 1 , keepdim = True )
105
- intersection : Tensor = target_sorted_sum - target_sorted . cumsum ( 1 )
106
- union : Tensor = target_sorted_sum + ( 1.0 - target_sorted ). cumsum ( 1 )
107
- gradient : Tensor = 1.0 - intersection / union
108
- if N > 1 :
109
- gradient [..., 1 :] = gradient [..., 1 :] - gradient [..., : - 1 ]
110
- loss : Tensor = ( errors_sorted . relu () * gradient ). sum ( 1 ). mean () * ( 1.0 if weight is None else weight [ c ])
111
- losses . append ( loss )
112
- final_loss : Tensor = torch . stack ( losses , dim = 0 ) .mean ()
95
+ foreground : Tensor = (
96
+ torch .nn . functional . one_hot ( target_flatten . to ( torch . int64 ), num_classes = C ). permute ( 0 , 2 , 1 ). to ( pred . dtype )
97
+ )
98
+ errors : Tensor = ( pred_soft - foreground ). abs ( )
99
+ errors_sorted , permutations = torch . sort ( errors , dim = 2 , descending = True )
100
+ batch_index = torch . arange ( B , device = pred . device ). unsqueeze ( 1 ). unsqueeze ( 2 ). expand ( B , C , N )
101
+ target_sorted = target_flatten [ batch_index , permutations ]
102
+ target_sorted_sum = target_sorted . sum ( 2 , keepdim = True )
103
+ intersection = target_sorted_sum - target_sorted . cumsum ( 2 )
104
+ union = target_sorted_sum + ( 1.0 - target_sorted ). cumsum ( 2 )
105
+ gradient = 1.0 - intersection / union
106
+ if N > 1 :
107
+ gradient [..., 1 :] = gradient [..., 1 :] - gradient [..., : - 1 ]
108
+ weighted_errors = errors_sorted * gradient
109
+ loss_per_class = weighted_errors . sum ( 2 ). mean ( 0 )
110
+ if weight is not None :
111
+ loss_per_class *= weight
112
+ final_loss : Tensor = loss_per_class .mean ()
113
113
return final_loss
114
114
115
115
0 commit comments