-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
123 lines (90 loc) · 3.5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os

import albumentations as A
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
# Shared preprocessing pipeline, applied in order:
#   1. resize every image to 224 x 224,
#   2. normalize with A.Normalize's default settings,
#   3. convert to a torch tensor via ToTensorV2 (per albumentations docs this
#      also moves channels first -- confirm against the installed version).
transform = A.Compose(
    transforms=[
        A.Resize(height=224, width=224),
        A.Normalize(),
        ToTensorV2(),
    ],
)
# Run-Length Encoding Function
def rle_encode(mask):
    '''
    Encode a mask as a run-length string.

    Pixels are read in row-major (C) order and numbered from 1. The result is
    a space-separated list of ``start length`` pairs, one pair per run of
    foreground pixels. Any non-zero value counts as foreground, so 0/1, 0/255
    and boolean masks all encode the same way. An all-background mask encodes
    to the empty string.

    Args:
        mask: array-like mask of any shape (anything ``np.asarray`` accepts).

    Returns:
        The RLE string, e.g. ``"2 2"`` for the flattened mask ``[0, 1, 1, 0]``.
    '''
    # Binarise first. With several distinct non-zero labels (e.g. 1 next to 2),
    # every value change would count as a run boundary and the start/length
    # pairing below would come out misaligned.
    pixels = np.asarray(mask).reshape(-1) != 0
    # Pad with background on both ends so every run has both a start and an end.
    padded = np.concatenate([[False], pixels, [False]])
    # Positions where the value flips; +1 converts to 1-based pixel numbers.
    runs = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Boundaries alternate start, end, start, end, ...; turn each end into a length.
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)
# Run-Length Decoding Function
def rle_decode(mask_rle, shape):
    '''
    Decode a run-length string into a 0/1 ``uint8`` mask.

    ``mask_rle`` holds space-separated ``start length`` pairs. Starts are
    1-based positions in the row-major (C-order) flattening of ``shape``.

    Args:
        mask_rle: the RLE string. A non-string value (``None``, or NaN as
            pandas reports an empty cell) is treated as an empty mask.
        shape: output shape. Any number of dimensions is accepted; the flat
            buffer has ``prod(shape)`` pixels.

    Returns:
        ``np.uint8`` array of ``shape`` with 1 on encoded pixels and 0 elsewhere.
    '''
    flat = np.zeros(int(np.prod(shape)), dtype=np.uint8)
    # A missing RLE means "no foreground"; calling .split() on it would raise.
    if isinstance(mask_rle, str):
        tokens = mask_rle.split()
        starts = np.asarray(tokens[0::2], dtype=int) - 1  # 1-based -> 0-based
        lengths = np.asarray(tokens[1::2], dtype=int)
        # zip stops at the shorter sequence, so a trailing start with no
        # matching length (e.g. a lone "-1") is ignored.
        for start, length in zip(starts, lengths):
            flat[start:start + length] = 1
    return flat.reshape(shape)
class DiceLoss(nn.Module):
    """Soft Dice loss, ``1 - Dice``, pooled over every element of the batch.

    ``inputs`` is used exactly as given: no sigmoid/softmax is applied here.
    It is therefore presumably expected to hold probabilities or hard 0/1
    predictions rather than raw logits; confirm with callers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1e-5):
        """Return ``1 - (2*sum(x*y) + smooth) / (sum(x) + sum(y) + smooth)``.

        Both tensors are flattened first, so one Dice value is computed for the
        whole batch rather than one per sample. ``smooth`` keeps the ratio
        defined when both tensors sum to zero.
        """
        flat_pred = inputs.reshape(-1)
        flat_true = targets.reshape(-1)
        overlap = (flat_pred * flat_true).sum()
        denom = flat_pred.sum() + flat_true.sum() + smooth
        dice = (2.0 * overlap + smooth) / denom
        return 1.0 - dice
def dice_loss(pred, target):
    """Soft Dice loss computed from raw logits.

    ``pred`` is passed through ``torch.sigmoid``. Both tensors are then
    flattened, so the whole batch contributes to a single score. ``target``
    may be binary or real-valued. Every operation is differentiable with
    respect to ``pred``.

    Args:
        pred: logits tensor, first dimension is the batch.
        target: tensor with the same number of elements as ``pred``.

    Returns:
        0-dim tensor ``1 - (2*sum(p*t) + 1e-7) / (sum(p) + sum(t) + 1e-7)``,
        where ``p = sigmoid(pred)``.
    """
    eps = 1e-7
    # reshape copies when the tensor is not contiguous (e.g. after .t() or a
    # view), which is what the contiguous().view() pair achieved explicitly.
    probs = torch.sigmoid(pred).reshape(-1)
    truth = target.reshape(-1)
    overlap = (probs * truth).sum()
    total = probs.sum() + truth.sum()
    return 1 - ((2. * overlap + eps) / (total + eps))
def dice_score(prediction: np.ndarray, ground_truth: np.ndarray, smooth: float = 1e-7) -> float:
    '''
    Calculate the Dice score between two binary masks.

    Dice = (2 * sum(P * G) + smooth) / (sum(P) + sum(G) + smooth)

    Args:
        prediction: predicted mask; any array-like accepted by ``np.asarray``
            (0/1 or boolean values expected).
        ground_truth: reference mask, broadcast-compatible with ``prediction``
            (normally the same shape).
        smooth: small constant added to numerator and denominator, so two
            empty masks score 1.0 instead of dividing by zero.

    Returns:
        The Dice coefficient; 1.0 means the masks overlap perfectly.
    '''
    # np.asarray lets plain lists work; ndarray inputs pass through unchanged.
    prediction = np.asarray(prediction)
    ground_truth = np.asarray(ground_truth)
    intersection = np.sum(prediction * ground_truth)
    return (2.0 * intersection + smooth) / (np.sum(prediction) + np.sum(ground_truth) + smooth)
def calculate_dice_scores(ground_truth_df, prediction_df, img_shape=(224, 224)):
    '''
    Calculate the mean Dice score between predicted and ground-truth RLE masks.

    Both dataframes are read by position: column 0 holds the image id and
    column 1 the run-length-encoded mask, which is decoded with ``rle_decode``.

    Predictions are matched to ground-truth rows by image id, not by row
    position, so the two frames may be ordered differently and
    ``prediction_df`` may contain ids that have no ground truth.
    - A ground-truth id with no prediction is scored against an empty mask.
    - If an id appears more than once in ``prediction_df``, the last row wins.
    - Non-string cells (e.g. NaN for an empty field) are treated as empty masks.

    Pairs where both masks are empty are skipped and do not affect the mean.

    Args:
        ground_truth_df: reference (id, rle) rows.
        prediction_df: predicted (id, rle) rows.
        img_shape: shape passed to ``rle_decode`` for both masks.

    Returns:
        Mean Dice over the non-skipped pairs, or NaN if every pair was skipped.
    '''
    def _as_rle(value):
        # Map missing cells (NaN/None) to "", which decodes to an empty mask.
        return value if isinstance(value, str) else ''

    def calculate_dice(pred_rle, gt_rle):
        pred_mask = rle_decode(_as_rle(pred_rle), img_shape)
        gt_mask = rle_decode(_as_rle(gt_rle), img_shape)
        if np.sum(gt_mask) > 0 or np.sum(pred_mask) > 0:
            return dice_score(pred_mask, gt_mask)
        return None  # both empty: nothing to score

    # Look predictions up by id. Zipping rows by position silently mis-pairs
    # images whenever the two frames are ordered differently.
    pred_rle_by_id = dict(zip(prediction_df.iloc[:, 0], prediction_df.iloc[:, 1]))

    # Computed serially: the joblib Parallel/delayed helpers this used before
    # were never imported, and decoding one small mask per row is cheap.
    dice_scores = [
        calculate_dice(pred_rle_by_id.get(img_id, ''), gt_rle)
        for img_id, gt_rle in zip(ground_truth_df.iloc[:, 0], ground_truth_df.iloc[:, 1])
    ]
    dice_scores = [score for score in dice_scores if score is not None]
    if not dice_scores:
        # Same NaN np.mean([]) would give, without its RuntimeWarning.
        return np.nan
    return np.mean(dice_scores)
def save_model(model, fname):
    '''
    Save ``model.state_dict()`` to ``./models/<fname>.pt`` and print the path.

    The path is relative to the current working directory. The ``./models``
    directory, plus any sub-directories contained in ``fname``, is created if
    it does not exist yet.

    Args:
        model: object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        fname: file name without the ``.pt`` extension.

    Returns:
        The path the state dict was written to.
    '''
    model_path = f'./models/{fname}.pt'
    # torch.save does not create missing parent directories; without this it
    # raises FileNotFoundError on a fresh checkout.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)
    print(f"Model saved at {model_path}")
    return model_path