@@ -7,18 +7,18 @@ def dice_ce_loss(pred, tgt):
7
7
return (dice + ce ) / 2
8
8
9
9
def sigmoid_focal_loss (pred :Tensor , tgt :Tensor , alpha :float = 0.25 , gamma :float = 2.0 , reduction :str = "none" ) -> Tensor :
10
- assert reduction in ["mean" , "sum" , "none" ], f"unsupported reduction { reduction } "
11
- p , ce_loss = pred .sigmoid (), pred .binary_crossentropy_logits (tgt , reduction = "none" )
12
- p_t = p * tgt + (1 - p ) * (1 - tgt )
13
- loss = ce_loss * ((1 - p_t ) ** gamma )
10
+ assert reduction in ["mean" , "sum" , "none" ], f"unsupported reduction { reduction } "
11
+ p , ce_loss = pred .sigmoid (), pred .binary_crossentropy_logits (tgt , reduction = "none" )
12
+ p_t = p * tgt + (1 - p ) * (1 - tgt )
13
+ loss = ce_loss * ((1 - p_t ) ** gamma )
14
14
15
- if alpha >= 0 :
16
- alpha_t = alpha * tgt + (1 - alpha ) * (1 - tgt )
17
- loss = loss * alpha_t
15
+ if alpha >= 0 :
16
+ alpha_t = alpha * tgt + (1 - alpha ) * (1 - tgt )
17
+ loss = loss * alpha_t
18
18
19
- if reduction == "mean" : loss = loss .mean ()
20
- elif reduction == "sum" : loss = loss .sum ()
21
- return loss
19
+ if reduction == "mean" : loss = loss .mean ()
20
+ elif reduction == "sum" : loss = loss .sum ()
21
+ return loss
22
22
23
23
def l1_loss (pred :Tensor , tgt :Tensor , reduction :str = "none" ) -> Tensor :
24
24
assert reduction in ["mean" , "sum" , "none" ], f"unsupported reduction { reduction } "
0 commit comments