Skip to content

Commit 2387a2d

Browse files
adds weight parameter to dice and lovasz_softmax losses (kornia#2879)
* adds weight parameter to dice and lovasz_softmax losses, similarly how we have focal loss * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * lint --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d1a1cc0 commit 2387a2d

File tree

4 files changed

+71
-10
lines changed

4 files changed

+71
-10
lines changed

kornia/losses/dice.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Optional
4+
35
import torch
46
from torch import nn
57

@@ -12,7 +14,9 @@
1214
# https://github.com/Lightning-AI/metrics/blob/v0.11.3/src/torchmetrics/functional/classification/dice.py#L66-L207
1315

1416

15-
def dice_loss(pred: Tensor, target: Tensor, average: str = "micro", eps: float = 1e-8) -> Tensor:
17+
def dice_loss(
18+
pred: Tensor, target: Tensor, average: str = "micro", eps: float = 1e-8, weight: Optional[Tensor] = None
19+
) -> Tensor:
1620
r"""Criterion that computes Sørensen-Dice Coefficient loss.
1721
1822
According to [1], we compute the Sørensen-Dice Coefficient as follows:
@@ -43,6 +47,7 @@ def dice_loss(pred: Tensor, target: Tensor, average: str = "micro", eps: float =
4347
- ``'micro'`` [default]: Calculate the loss across all classes.
4448
- ``'macro'``: Calculate the loss for each class separately and average the metrics across classes.
4549
eps: Scalar to enforce numerical stabiliy.
50+
weight: weights for classes with shape :math:`(num\_of\_classes,)`.
4651
4752
Return:
4853
One-element tensor of the computed loss.
@@ -64,7 +69,7 @@ def dice_loss(pred: Tensor, target: Tensor, average: str = "micro", eps: float =
6469

6570
if not pred.device == target.device:
6671
raise ValueError(f"pred and target must be in the same device. Got: {pred.device} and {target.device}")
67-
72+
num_of_classes = pred.shape[1]
6873
possible_average = {"micro", "macro"}
6974
KORNIA_CHECK(average in possible_average, f"The `average` has to be one of {possible_average}. Got: {average}")
7075

@@ -80,6 +85,18 @@ def dice_loss(pred: Tensor, target: Tensor, average: str = "micro", eps: float =
8085
dims = (1, *dims)
8186

8287
# compute the actual dice score
88+
if weight is not None:
89+
KORNIA_CHECK_IS_TENSOR(weight, "weight must be Tensor or None.")
90+
KORNIA_CHECK(
91+
(weight.shape[0] == num_of_classes and weight.numel() == num_of_classes),
92+
f"weight shape must be (num_of_classes,): ({num_of_classes},), got {weight.shape}",
93+
)
94+
KORNIA_CHECK(
95+
weight.device == pred.device,
96+
f"weight and pred must be in the same device. Got: {weight.device} and {pred.device}",
97+
)
98+
pred_soft = pred_soft * weight
99+
target_one_hot = target_one_hot * weight
83100
intersection = torch.sum(pred_soft * target_one_hot, dims)
84101
cardinality = torch.sum(pred_soft + target_one_hot, dims)
85102

@@ -120,6 +137,7 @@ class DiceLoss(nn.Module):
120137
- ``'micro'`` [default]: Calculate the loss across all classes.
121138
- ``'macro'``: Calculate the loss for each class separately and average the metrics across classes.
122139
eps: Scalar to enforce numerical stabiliy.
140+
weight: weights for classes with shape :math:`(num\_of\_classes,)`.
123141
124142
Shape:
125143
- Pred: :math:`(N, C, H, W)` where C = number of classes.
@@ -135,10 +153,11 @@ class DiceLoss(nn.Module):
135153
>>> output.backward()
136154
"""
137155

138-
def __init__(self, average: str = "micro", eps: float = 1e-8) -> None:
156+
def __init__(self, average: str = "micro", eps: float = 1e-8, weight: Optional[Tensor] = None) -> None:
139157
super().__init__()
140158
self.average = average
141159
self.eps = eps
160+
self.weight = weight
142161

143162
def forward(self, pred: Tensor, target: Tensor) -> Tensor:
144-
return dice_loss(pred, target, self.average, self.eps)
163+
return dice_loss(pred, target, self.average, self.eps, self.weight)

kornia/losses/lovasz_softmax.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
from __future__ import annotations
22

3+
from typing import Optional
4+
35
import torch
46
from torch import Tensor, nn
57

6-
from kornia.core.check import KORNIA_CHECK_SHAPE
8+
from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
79

810
# based on:
911
# https://github.com/bermanmaxim/LovaszSoftmax
1012

1113

12-
def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
14+
def lovasz_softmax_loss(pred: Tensor, target: Tensor, weight: Optional[Tensor] = None) -> Tensor:
1315
r"""Criterion that computes a surrogate multi-class intersection-over-union (IoU) loss.
1416
1517
According to [1], we compute the IoU as follows:
@@ -22,7 +24,7 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
2224
2325
Where:
2426
- :math:`X` expects to be the scores of each class.
25-
- :math:`Y` expects to be the binary tensor with the class labels.
27+
- :math:`Y` expects to be the long tensor with the class labels.
2628
2729
the loss, is finally computed as:
2830
@@ -41,6 +43,7 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
4143
pred: logits tensor with shape :math:`(N, C, H, W)` where C = number of classes > 1.
4244
labels: labels tensor with shape :math:`(N, H, W)` where each value
4345
is :math:`0 ≤ targets[i] ≤ C-1`.
46+
weight: weights for classes with shape :math:`(num\_of\_classes,)`.
4447
4548
Return:
4649
a scalar with the computed loss.
@@ -65,6 +68,19 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
6568
if not pred.device == target.device:
6669
raise ValueError(f"pred and target must be in the same device. Got: {pred.device} and {target.device}")
6770

71+
num_of_classes = pred.shape[1]
72+
# compute the actual dice score
73+
if weight is not None:
74+
KORNIA_CHECK_IS_TENSOR(weight, "weight must be Tensor or None.")
75+
KORNIA_CHECK(
76+
(weight.shape[0] == num_of_classes and weight.numel() == num_of_classes),
77+
f"weight shape must be (num_of_classes,): ({num_of_classes},), got {weight.shape}",
78+
)
79+
KORNIA_CHECK(
80+
weight.device == pred.device,
81+
f"weight and pred must be in the same device. Got: {weight.device} and {pred.device}",
82+
)
83+
6884
# flatten pred [B, C, -1] and target [B, -1] and to float
6985
pred_flatten: Tensor = pred.reshape(pred.shape[0], pred.shape[1], -1)
7086
target_flatten: Tensor = target.reshape(target.shape[0], -1).float()
@@ -91,7 +107,7 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
91107
gradient: Tensor = 1.0 - intersection / union
92108
if N > 1:
93109
gradient[..., 1:] = gradient[..., 1:] - gradient[..., :-1]
94-
loss: Tensor = (errors_sorted.relu() * gradient).sum(1).mean()
110+
loss: Tensor = (errors_sorted.relu() * gradient).sum(1).mean() * (1.0 if weight is None else weight[c])
95111
losses.append(loss)
96112
final_loss: Tensor = torch.stack(losses, dim=0).mean()
97113
return final_loss
@@ -129,6 +145,7 @@ class LovaszSoftmaxLoss(nn.Module):
129145
pred: logits tensor with shape :math:`(N, C, H, W)` where C = number of classes > 1.
130146
labels: labels tensor with shape :math:`(N, H, W)` where each value
131147
is :math:`0 ≤ targets[i] ≤ C-1`.
148+
weight: weights for classes with shape :math:`(num\_of\_classes,)`.
132149
133150
Return:
134151
a scalar with the computed loss.
@@ -142,8 +159,9 @@ class LovaszSoftmaxLoss(nn.Module):
142159
>>> output.backward()
143160
"""
144161

145-
def __init__(self) -> None:
162+
def __init__(self, weight: Optional[Tensor] = None) -> None:
146163
super().__init__()
164+
self.weight = weight
147165

148166
def forward(self, pred: Tensor, target: Tensor) -> Tensor:
149-
return lovasz_softmax_loss(pred=pred, target=target)
167+
return lovasz_softmax_loss(pred=pred, target=target, weight=self.weight)

tests/losses/test_dice.py

+11
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,17 @@ def test_averaging_micro(self, device, dtype):
6565
loss = criterion(logits, labels)
6666
self.assert_close(loss, expected_loss, rtol=1e-3, atol=1e-3)
6767

68+
def test_weight(self, device, dtype):
69+
num_classes = 3
70+
eps = 1e-8
71+
logits = torch.zeros(2, num_classes, 4, 1, device=device, dtype=dtype)
72+
labels = torch.zeros(2, 4, 1, device=device, dtype=torch.int64)
73+
expected_loss = torch.tensor([2.0 / 3.0], device=device, dtype=dtype).squeeze()
74+
weight = torch.tensor([0.0, 1.0, 1.0], device=device, dtype=dtype)
75+
criterion = kornia.losses.DiceLoss(average="micro", eps=eps, weight=weight)
76+
loss = criterion(logits, labels)
77+
self.assert_close(loss, expected_loss, rtol=1e-3, atol=1e-3)
78+
6879
def test_averaging_macro(self, device, dtype):
6980
num_classes = 2
7081
eps = 1e-8

tests/losses/test_lovaz_softmax.py

+13
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@ def test_all_ones(self, device, dtype):
6262

6363
self.assert_close(loss, torch.zeros_like(loss), rtol=1e-3, atol=1e-3)
6464

65+
def test_weight(self, device, dtype):
66+
num_classes = 2
67+
# make perfect prediction
68+
# note that softmax(prediction[:, 1]) == 1. softmax(prediction[:, 0]) == 0.
69+
prediction = torch.zeros(2, num_classes, 1, 2, device=device, dtype=dtype)
70+
prediction[:, 0] = 100.0
71+
labels = torch.ones(2, 1, 2, device=device, dtype=torch.int64)
72+
73+
criterion = kornia.losses.LovaszSoftmaxLoss(weight=torch.tensor([1.0, 0.0], device=device, dtype=dtype))
74+
loss = criterion(prediction, labels)
75+
76+
self.assert_close(loss, 0.5 * torch.ones_like(loss), rtol=1e-3, atol=1e-3)
77+
6578
def test_gradcheck(self, device):
6679
num_classes = 4
6780
logits = torch.rand(2, num_classes, 3, 2, device=device, dtype=torch.float64)

0 commit comments

Comments
 (0)