from tinygrad import Tensor

from test.external.mlperf_retinanet.focal_loss import sigmoid_focal_loss as ref_sigmoid_focal_loss
from test.external.mlperf_unet3d.dice import DiceCELoss
from examples.mlperf.losses import dice_ce_loss, sigmoid_focal_loss, l1_loss

import numpy as np
import torch
import unittest
class ExternalTestLosses(unittest.TestCase):
  """Compare tinygrad MLPerf loss implementations against their reference (torch) counterparts.

  Each test builds identical numpy inputs, feeds them to both the tinygrad loss and the
  reference loss, and asserts numerical agreement via np.testing.assert_allclose.
  """

  def setUp(self):
    # Fixed seed so every test sees the same deterministic random inputs.
    np.random.seed(42)

  def _assert_loss(self, pred, tgt, tinygrad_metrics, ref_metrics, rtol=1e-07, atol=0, **kwargs):
    """Evaluate both loss implementations on the same data and check they agree.

    pred/tgt: numpy arrays fed to both implementations.
    tinygrad_metrics: callable taking tinygrad Tensors; ref_metrics: callable taking torch tensors.
    kwargs are forwarded unchanged to both (e.g. reduction, alpha, gamma).
    """
    tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(tgt), **kwargs)
    ref_metrics_res = ref_metrics(torch.from_numpy(pred), torch.from_numpy(tgt), **kwargs)
    np.testing.assert_allclose(tinygrad_metrics_res.numpy(), ref_metrics_res.numpy(), rtol=rtol, atol=atol)

  def test_dice_ce_loss(self):
    # NCDHW volume with a single-class label volume (all ones).
    pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
    # Pass the loss callables directly; the previous `_res` names wrongly suggested results.
    self._assert_loss(pred, label, dice_ce_loss, DiceCELoss(True, True, "NCDHW", False), atol=1e-4)

  def test_sigmoid_focal_loss(self):
    def _apply_logit(p): return np.log(p / (1 - p))
    # Logits derived from probabilities in (0, 1); binary {0, 1} float targets.
    pred, tgt = _apply_logit(np.random.rand(5, 2).astype(np.float32)), np.random.randint(0, 2, (5, 2)).astype(np.float32)
    for reduction in ["mean", "sum", "none"]:
      # alpha=-1 disables alpha-weighting (reference convention); gamma=0 reduces to plain BCE.
      for alpha, gamma in zip([-1, 0.58], [0, 2]):
        self._assert_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, rtol=1e-4, alpha=alpha, gamma=gamma, reduction=reduction)

  def test_l1_loss(self):
    N, C, H, W = 3, 4, 5, 6
    shapes = ((N, C), (N, C, H), (N, C, H, W))
    for reduction in ["mean", "sum", "none"]:
      for shape in shapes:
        # BUG FIX: np.random.randint(shape) treats the tuple as per-element upper bounds and
        # returns a 1-D int array of len(shape) values — not an array of the intended shape.
        # Use rand(*shape) so pred/tgt are float arrays of the declared shape for both sides.
        pred, tgt = np.random.rand(*shape).astype(np.float32), np.random.rand(*shape).astype(np.float32)
        self._assert_loss(pred, tgt, l1_loss, torch.nn.functional.l1_loss, reduction=reduction)
# Allow running this file directly; also strips the diff-render residue that trailed it.
if __name__ == '__main__':
  unittest.main()