Commit 8cbe400

RetinaNet losses (tinygrad#9536)
* add sigmoid_focal_loss and l1_loss
* update ref implementation comment
1 parent e638918 commit 8cbe400

3 files changed (+101 -7 lines)
examples/mlperf/losses.py (+23)

```python
from examples.mlperf.metrics import dice_score
from tinygrad import Tensor

def dice_ce_loss(pred, tgt):
  ce = pred.permute(0, 2, 3, 4, 1).sparse_categorical_crossentropy(tgt.squeeze(1))
  dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
  return (dice + ce) / 2

def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float=0.25, gamma:float=2.0, reduction:str="none") -> Tensor:
  assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
  p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
  p_t = p * tgt + (1 - p) * (1 - tgt)
  loss = ce_loss * ((1 - p_t) ** gamma)

  if alpha >= 0:
    alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
    loss = loss * alpha_t

  if reduction == "mean": loss = loss.mean()
  elif reduction == "sum": loss = loss.sum()
  return loss

def l1_loss(pred:Tensor, tgt:Tensor, reduction:str="none") -> Tensor:
  assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
  loss = (pred - tgt).abs()

  if reduction == "mean": loss = loss.mean()
  elif reduction == "sum": loss = loss.sum()
  return loss
```
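The new `sigmoid_focal_loss` follows the RetinaNet formulation FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t): the per-element binary cross-entropy is scaled by (1 - p_t)^gamma so well-classified examples contribute little, and by alpha_t when alpha >= 0. A minimal usage sketch of the two new functions (the tensor values and shapes are illustrative, not from this commit):

```python
from tinygrad import Tensor
from examples.mlperf.losses import sigmoid_focal_loss, l1_loss

# raw classification logits and binary targets for two anchors x two classes (illustrative values)
pred = Tensor([[2.0, -1.0], [0.5, 0.25]])
tgt = Tensor([[1.0, 0.0], [0.0, 1.0]])

print(sigmoid_focal_loss(pred, tgt, reduction="mean").item())  # scalar focal loss
print(l1_loss(pred, tgt, reduction="sum").item())              # scalar L1 loss
```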

test/external/external_test_losses.py (+27 -7)

```diff
@@ -1,20 +1,40 @@
 from tinygrad import Tensor
+from test.external.mlperf_retinanet.focal_loss import sigmoid_focal_loss as ref_sigmoid_focal_loss
 from test.external.mlperf_unet3d.dice import DiceCELoss
-from examples.mlperf.losses import dice_ce_loss
+from examples.mlperf.losses import dice_ce_loss, sigmoid_focal_loss, l1_loss
 
 import numpy as np
 import torch
 import unittest
 
 class ExternalTestLosses(unittest.TestCase):
-  def _test_losses(self, tinygrad_metrics, orig_metrics, pred, label):
-    tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(label)).numpy()
-    orig_metrics_res = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label)).numpy()
-    np.testing.assert_allclose(tinygrad_metrics_res, orig_metrics_res, atol=1e-4)
+  def setUp(self):
+    np.random.seed(42)
 
-  def test_dice_ce(self):
+  def _assert_loss(self, pred, tgt, tinygrad_metrics, ref_metrics, rtol=1e-07, atol=0, **kwargs):
+    tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(tgt), **kwargs)
+    ref_metrics_res = ref_metrics(torch.from_numpy(pred), torch.from_numpy(tgt), **kwargs)
+    np.testing.assert_allclose(tinygrad_metrics_res.numpy(), ref_metrics_res.numpy(), rtol=rtol, atol=atol)
+
+  def test_dice_ce_loss(self):
     pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
-    self._test_losses(dice_ce_loss, DiceCELoss(True, True, "NCDHW", False), pred, label)
+    tinygrad_metrics_res, ref_metrics_res = dice_ce_loss, DiceCELoss(True, True, "NCDHW", False)
+    self._assert_loss(pred, label, tinygrad_metrics_res, ref_metrics_res, atol=1e-4)
+
+  def test_sigmoid_focal_loss(self):
+    def _apply_logit(p): return np.log(p / (1 - p))
+    pred, tgt = _apply_logit(np.random.rand(5,2).astype(np.float32)), np.random.randint(0, 2, (5, 2)).astype(np.float32)
+    for reduction in ["mean", "sum", "none"]:
+      for alpha, gamma in zip([-1, 0.58], [0, 2]):
+        self._assert_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, rtol=1e-4, alpha=alpha, gamma=gamma, reduction=reduction)
+
+  def test_l1_loss(self):
+    N, C, H, W = 3, 4, 5, 6
+    shapes = ((N, C), (N, C, H), (N, C, H, W))
+    for reduction in ["mean", "sum", "none"]:
+      for shape in shapes:
+        pred, tgt = np.random.randn(*shape).astype(np.float32), np.random.randn(*shape).astype(np.float32)
+        self._assert_loss(pred, tgt, l1_loss, torch.nn.functional.l1_loss, reduction=reduction)
 
 if __name__ == '__main__':
   unittest.main()
```
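One detail worth noting in the test above: `_apply_logit` is the inverse of the sigmoid, so the test draws probabilities in (0, 1) and maps them to well-defined logits for `pred`. A standalone sketch of that round trip (illustrative only, not part of the commit):

```python
import numpy as np

p = np.random.rand(5, 2).astype(np.float32)   # probabilities in (0, 1)
logits = np.log(p / (1 - p))                  # logit, i.e. inverse sigmoid
recovered = 1.0 / (1.0 + np.exp(-logits))     # sigmoid maps the logits back
assert np.allclose(recovered, p, atol=1e-6)   # round trip recovers p
```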
test/external/mlperf_retinanet/focal_loss.py (+51, new file)

```python
# Copied from https://github.com/mlcommons/training/blob/cdd928d4596c142c15a7d86b2eeadbac718c8da2/single_stage_detector/ssd/model/focal_loss.py

import torch
import torch.nn.functional as F


def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
    reduction: str = "none",
):
    """
    Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples or -1 for ignore. Default = 0.25
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        reduction: 'none' | 'mean' | 'sum'
                   'none': No reduction will be applied to the output.
                   'mean': The output will be averaged.
                   'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    """
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
```
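To see the modulating factor (1 - p_t)^gamma at work, one can feed this reference implementation an easy (confidently correct) and a hard (confidently wrong) prediction; a small sanity check with illustrative values, run from the repo root:

```python
import torch
from test.external.mlperf_retinanet.focal_loss import sigmoid_focal_loss

easy = torch.tensor([4.0])   # confident logit for a positive target -> p_t close to 1
hard = torch.tensor([-4.0])  # confident logit for the wrong class -> p_t close to 0
target = torch.tensor([1.0])

# (1 - p_t)^gamma shrinks the easy example's loss far more than the hard one's
print(sigmoid_focal_loss(easy, target, gamma=2.0))  # tiny loss (~1e-6)
print(sigmoid_focal_loss(hard, target, gamma=2.0))  # much larger loss (~1)
```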
