# dataset.py (forked from ZhengPeng7/BiRefNet)
import os
import random
import numpy as np
import cv2
from tqdm import tqdm
from PIL import Image
from torch.utils import data
from torchvision import transforms
from image_proc import preproc
from config import Config
from utils import path_to_image
Image.MAX_IMAGE_PIXELS = None # remove DecompressionBombWarning
config = Config()
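
# Class names encoded in the training-set label filenames ('#'-separated fields);
# used below to build cls_name2id for the auxiliary classification head.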
_class_labels_TR_sorted = (
    'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, '
    'BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, '
    'CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, '
    'Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, '
    'Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, '
    'Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, '
    'KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, '
    'Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, '
    'OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, '
    'RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, '
    'ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, '
    'Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, '
    'TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, '
    'UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht'
)
class_labels_TR_sorted = _class_labels_TR_sorted.split(', ')

class MyData(data.Dataset):
    def __init__(self, datasets, data_size, is_train=True):
        # data_size is None when dynamic_size is used, or when it is manually set to None (for inference at the original size).
        self.is_train = is_train
        self.data_size = data_size
        self.load_all = config.load_all
        self.device = config.device
        valid_extensions = ['.png', '.jpg', '.PNG', '.JPG', '.JPEG']
        if self.is_train and config.auxiliary_classification:
            self.cls_name2id = {_name: _id for _id, _name in enumerate(class_labels_TR_sorted)}
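        # The slices below use bool-as-int arithmetic: when images are pre-loaded
        # (load_all, already resized at load time) or kept at their original size
        # (data_size is None), the condition is True (== 1), so the Compose starts
        # at ToTensor and the leading Resize is skipped.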
        self.transform_image = transforms.Compose([
            transforms.Resize(config.size[::-1]),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ][self.load_all or self.data_size is None::])
        self.transform_label = transforms.Compose([
            transforms.Resize(config.size[::-1]),
            transforms.ToTensor(),
        ][self.load_all or self.data_size is None::])
        dataset_root = os.path.join(config.data_root_dir, config.task)
        # datasets is a '+'-joined string of dataset names, for training on combined sets.
        self.image_paths = []
        for dataset in datasets.split('+'):
            image_root = os.path.join(dataset_root, dataset, 'im')
            self.image_paths += [os.path.join(image_root, p) for p in os.listdir(image_root) if any(p.endswith(ext) for ext in valid_extensions)]
        self.label_paths = []
        for p in self.image_paths:
            file_exists = False
            for ext in valid_extensions:
                ## 'im' and 'gt' may need modifying
                p_gt = p.replace('/im/', '/gt/')[:-(len(p.split('.')[-1])+1)] + ext
                if os.path.exists(p_gt):
                    self.label_paths.append(p_gt)
                    file_exists = True
                    break
            if not file_exists:
                print('Label does not exist:', p_gt)
        if len(self.label_paths) != len(self.image_paths):
            set_image_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.image_paths])
            set_label_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.label_paths])
            print('Path diff:', set_image_paths - set_label_paths)
            raise ValueError(f"There are different numbers of images ({len(self.image_paths)}) and labels ({len(self.label_paths)})")
        if self.load_all:
            self.images_loaded, self.labels_loaded = [], []
            self.class_labels_loaded = []
            for image_path, label_path in tqdm(zip(self.image_paths, self.label_paths), total=len(self.image_paths)):
                _image = path_to_image(image_path, size=self.data_size, color_type='rgb')
                _label = path_to_image(label_path, size=self.data_size, color_type='gray')
                self.images_loaded.append(_image)
                self.labels_loaded.append(_label)
                self.class_labels_loaded.append(
                    # The class name is the 4th '#'-separated field of the label filename.
                    self.cls_name2id[label_path.split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1
                )
    def __getitem__(self, index):
        if self.load_all:
            image = self.images_loaded[index]
            label = self.labels_loaded[index]
            class_label = self.class_labels_loaded[index] if self.is_train and config.auxiliary_classification else -1
        else:
            image = path_to_image(self.image_paths[index], size=self.data_size, color_type='rgb')
            label = path_to_image(self.label_paths[index], size=self.data_size, color_type='gray')
            class_label = self.cls_name2id[self.label_paths[index].split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1
        # Training-time augmentation: optional background color synthesis, then the standard preproc methods.
        if self.is_train:
            if config.background_color_synthesis:
                image.putalpha(label)
                array_image = np.array(image)
                array_foreground = array_image[:, :, :3].astype(np.float32)
                array_mask = (array_image[:, :, 3:] / 255).astype(np.float32)
                array_background = np.zeros_like(array_foreground)
                choice = random.random()
                if choice < 0.4:
                    # Black/gray/white backgrounds.
                    array_background[:, :, :] = random.randint(0, 255)
                elif choice < 0.8:
                    # Background color similar to the foreground object: hard negative samples.
                    # The all-pixel mean is rescaled by (total pixels / foreground pixels) to get the mean over foreground pixels only.
                    foreground_pixel_number = np.sum(array_mask > 0)
                    color_foreground_mean = np.mean(array_foreground * array_mask, axis=(0, 1)) * (np.prod(array_foreground.shape[:2]) / foreground_pixel_number)
                    color_up_or_down = random.choice((-1, 1))
                    # Shift up or down by at most 20% of the distance to 255 or 0, respectively.
                    color_foreground_mean += (255 - color_foreground_mean if color_up_or_down == 1 else color_foreground_mean) * (random.random() * 0.2) * color_up_or_down
                    array_background[:, :, :] = color_foreground_mean
                else:
                    # Any color, sampled per channel.
                    for idx_channel in range(3):
                        array_background[:, :, idx_channel] = random.randint(0, 255)
                array_foreground_background = array_foreground * array_mask + array_background * (1 - array_mask)
                image = Image.fromarray(array_foreground_background.astype(np.uint8))
            image, label = preproc(image, label, preproc_methods=config.preproc_methods)
        # else:
        #     if _label.shape[0] > 2048 or _label.shape[1] > 2048:
        #         _image = cv2.resize(_image, (2048, 2048), interpolation=cv2.INTER_LINEAR)
        #         _label = cv2.resize(_label, (2048, 2048), interpolation=cv2.INTER_LINEAR)

        # At present, we use fixed sizes for inference, instead of dynamic sizes consistent with training.
        if self.is_train:
            if config.dynamic_size is None:
                image, label = self.transform_image(image), self.transform_label(label)
            # With dynamic_size, resizing and tensor conversion are deferred to custom_collate_fn so the whole batch shares one size.
        else:
            size_div_32 = (int(image.size[0] // 32 * 32), int(image.size[1] // 32 * 32))
            if image.size != size_div_32:
                image = image.resize(size_div_32)
                label = label.resize(size_div_32)
            image, label = self.transform_image(image), self.transform_label(label)
        if self.is_train:
            return image, label, class_label
        else:
            return image, label, self.label_paths[index]

    def __len__(self):
        return len(self.image_paths)

def custom_collate_fn(batch):
    if config.dynamic_size:
        dynamic_size = tuple(sorted(config.dynamic_size))
        # Select one size per batch, uniformly at random from [dynamic_size[0/1][0], dynamic_size[0/1][1]] and rounded down to a multiple of 32.
        dynamic_size_batch = (random.randint(dynamic_size[0][0], dynamic_size[0][1]) // 32 * 32, random.randint(dynamic_size[1][0], dynamic_size[1][1]) // 32 * 32)
        data_size = dynamic_size_batch
    else:
        data_size = config.size
    new_batch = []
    transform_image = transforms.Compose([
        transforms.Resize(data_size[::-1]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    transform_label = transforms.Compose([
        transforms.Resize(data_size[::-1]),
        transforms.ToTensor(),
    ])
    for image, label, class_label in batch:
        new_batch.append((transform_image(image), transform_label(label), class_label))
    return data._utils.collate.default_collate(new_batch)
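
# Usage sketch: a hedged example, not code from this repo. The dataset name
# 'DIS-TR' and the batch size are placeholder assumptions. With config.dynamic_size
# set, pass data_size=None so that __getitem__ returns PIL images and
# custom_collate_fn resizes each batch to one shared, 32-divisible size.
#
#     from torch.utils.data import DataLoader
#     train_set = MyData(datasets='DIS-TR', data_size=None if config.dynamic_size else config.size, is_train=True)
#     train_loader = DataLoader(
#         train_set, batch_size=4, shuffle=True,
#         collate_fn=custom_collate_fn if config.dynamic_size else None,
#     )
#     for images, labels, class_labels in train_loader:
#         ...  # training step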