Skip to content

Commit 1a1087e

Browse files
authored
cleanups on losses and dataset tests (tinygrad#9538)
1 parent 8cbe400 commit 1a1087e

File tree

2 files changed: +11 −11 lines changed

examples/mlperf/losses.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@ def dice_ce_loss(pred, tgt):
77
return (dice + ce) / 2
88

99
def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float=0.25, gamma:float=2.0, reduction:str="none") -> Tensor:
  """Sigmoid focal loss (Lin et al., "Focal Loss for Dense Object Detection").

  pred: raw logits; tgt: binary targets with the same shape as pred.
  alpha: class-balance weight for positives in [0, 1]; a negative value disables the alpha term.
  gamma: focusing exponent that down-weights well-classified examples.
  reduction: one of "none" (default), "mean", or "sum".
  Returns the per-element loss tensor, or its mean/sum per `reduction`.
  """
  assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
  p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
  # p_t is the model's probability of the true class; (1 - p_t)^gamma modulates the CE loss
  p_t = p * tgt + (1 - p) * (1 - tgt)
  loss = ce_loss * ((1 - p_t) ** gamma)

  if alpha >= 0:
    # alpha-balanced variant: weight positives by alpha, negatives by (1 - alpha)
    alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
    loss = loss * alpha_t

  if reduction == "mean": loss = loss.mean()
  elif reduction == "sum": loss = loss.sum()
  return loss
2222

2323
def l1_loss(pred:Tensor, tgt:Tensor, reduction:str="none") -> Tensor:
2424
assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"

test/external/external_test_datasets.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def _create_tinygrad_dataloader(self, preproc_pth, val, batch_size=1, shuffle=Fa
6060
if use_old_dataloader:
6161
dataset = iterate(list(Path(tempfile.gettempdir()).glob("case_*")), preprocessed_dir=preproc_pth, val=val, shuffle=shuffle, bs=batch_size)
6262
else:
63-      dataset = iter(batch_load_unet3d(preproc_pth, batch_size=batch_size, val=val, shuffle=shuffle, seed=seed))
63+      dataset = batch_load_unet3d(preproc_pth, batch_size=batch_size, val=val, shuffle=shuffle, seed=seed)
6464

6565
return iter(dataset)
6666

0 commit comments

Comments
 (0)