-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
82 lines (61 loc) · 2.07 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import cv2
import numpy as np
import torchvision
import warnings
warnings.filterwarnings("ignore")
def open_img(path):
    """
    Read an image file from disk and return it as an RGB numpy array.

    Parameters
    ----------
    path : str
        Path to the image file.

    Returns
    -------
    numpy.ndarray
        Image in RGB channel order (OpenCV loads BGR, so we convert).

    Raises
    ------
    FileNotFoundError
        If the file is missing or unreadable. cv2.imread signals failure
        by returning None, which previously produced a cryptic cvtColor
        error instead of a clear message.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def get_encoder(model, pretrained=True):
    """
    Build a torchvision backbone and return it with decoder filter sizes.

    Parameters
    ----------
    model : str
        One of "resnet18", "resnet34", "resnet50", "resnext50",
        "resnext101".
    pretrained : bool
        Passed through to torchvision to load ImageNet weights.

    Returns
    -------
    (torch.nn.Module, list[int])
        The encoder and the channel counts of its feature stages
        (deepest first), used to size a decoder.

    Raises
    ------
    ValueError
        If `model` is not a supported architecture name. The original
        if/elif chain had no else branch, so an unknown name crashed
        with an UnboundLocalError instead of a clear error.
    """
    constructors = {
        "resnet18": torchvision.models.resnet18,
        "resnet34": torchvision.models.resnet34,
        "resnet50": torchvision.models.resnet50,
        "resnext50": torchvision.models.resnext50_32x4d,
        "resnext101": torchvision.models.resnext101_32x8d,
    }
    if model not in constructors:
        raise ValueError(
            f"Unsupported encoder '{model}'; expected one of {sorted(constructors)}"
        )
    encoder = constructors[model](pretrained=pretrained)
    # The shallow resnets end in 512 channels; the 50/101-class nets in 2048.
    filters_dict = {
        "resnet18-34": [512, 512, 256, 128, 64],
        "resnet50-101": [2048, 2048, 1024, 512, 256],
    }
    family = "resnet18-34" if model in ("resnet18", "resnet34") else "resnet50-101"
    return encoder, filters_dict[family]
def invert_mask(mask):
    """Return the logical complement of `mask` as a 0/1 integer array."""
    as_bool = mask.astype(bool)
    return np.logical_not(as_bool).astype(int)
class DropClusters:
    '''
    Post-processing of a predicted binary mask: connected components
    with `min_size` pixels or fewer are dropped.

    `drop` runs `filt_invert` twice: the first pass filters foreground
    components and returns the inverted result, the second pass filters
    that inverse (removing small background holes) and inverts back,
    restoring the original polarity.
    '''

    @classmethod
    def drop(cls, mask, min_size=50 * 50):
        """
        Remove small clusters from `mask` (both foreground and holes).

        Parameters
        ----------
        mask : numpy.ndarray
            2-D binary mask (nonzero = foreground).
        min_size : int
            Components must have strictly more than this many pixels
            to be kept. Default 2500 (50 * 50).

        Returns
        -------
        numpy.ndarray
            Filtered 0/1 integer mask of the same spatial shape.
        """
        # Stored on the class so filt_invert keeps its original signature.
        cls.min_size = min_size
        for _ in range(2):
            mask = cls.filt_invert(mask)
        return mask

    @classmethod
    def filt_invert(cls, mask):
        """Keep components strictly larger than cls.min_size, then invert."""
        num_component, component = cv2.connectedComponents(mask.astype(np.uint8))
        # np.int was removed in NumPy 1.24; it was simply an alias for the
        # builtin int, so use that directly.
        predictions = np.zeros(mask.shape[:2], int)
        # Label 0 is the background; iterate only over real components.
        for c in range(1, num_component):
            p = component == c
            if p.sum() > cls.min_size:
                predictions[p] = 1
        return invert_mask(predictions)
def load_train_config(path="train_config.yaml"):
    """
    Load a YAML training configuration file into a Python object.

    Parameters
    ----------
    path : str
        Path to the YAML file (default "train_config.yaml").

    Returns
    -------
    The parsed YAML content (typically a dict).
    """
    # `yaml` was referenced without ever being imported anywhere in this
    # file (NameError at call time), so import it locally here.
    import yaml
    with open(path) as f:
        # safe_load: yaml.load without an explicit Loader is deprecated
        # and unsafe on untrusted input.
        return yaml.safe_load(f)
def torch2np(outputs):
    """
    Convert a batch-of-one CHW tensor to an HWC numpy array.

    Parameters
    ----------
    outputs : torch.Tensor
        Tensor of shape (1, C, H, W) (the leading batch dim is squeezed).

    Returns
    -------
    numpy.ndarray
        Array of shape (H, W, C).
    """
    # detach().cpu() lets this also work for tensors that live on the GPU
    # or require grad; .numpy() alone raises a RuntimeError for those.
    return outputs.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()