-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict_poll.py
159 lines (140 loc) · 7.79 KB
/
predict_poll.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import argparse
from image_complete import auto
import traceback
from sfp import Poller
from predict_common import prediction_to_file, PREDICTION_FORMATS, PREDICTION_FORMAT_GRAYSCALE, load_model, classes_dict
import paddle
from paddleseg.core.predict import preprocess
from paddleseg.core import infer
# File extensions (lower case) that the poller picks up as input images.
SUPPORTED_EXTS = [".jpg", ".jpeg"]
def check_image(fname, poller):
    """
    Verifies that an image file has been fully written before it gets processed.

    :param fname: the path of the image file to inspect
    :type fname: str
    :param poller: the Poller instance invoking this check (used for debug logging)
    :type poller: Poller
    :return: the result of auto.is_image_complete, True if the image is complete
    :rtype: bool
    """
    complete = auto.is_image_complete(fname)
    # log the outcome so incomplete/partial uploads can be diagnosed in verbose mode
    poller.debug("Image complete:", fname, "->", complete)
    return complete
def process_image(fname, output_dir, poller):
    """
    Method for processing an image.

    Runs the model stored in ``poller.params.model`` on the image (after applying
    ``poller.params.transforms``) and writes the prediction to a ``.png`` file in
    the output directory, named after the input image.

    :param fname: the image to process
    :type fname: str
    :param output_dir: the directory to write the image to
    :type output_dir: str
    :param poller: the Poller instance that called the method
    :type poller: Poller
    :return: the list of generated output files (empty if processing failed)
    :rtype: list
    """
    result = []
    try:
        # Take model/transforms from the poller parameters (set by predict_on_images)
        # instead of module globals, which only exist when this file is run as a script.
        params = poller.params
        # TODO batches?
        with paddle.no_grad():
            data = preprocess(fname, params.transforms)
            pred, _ = infer.inference(
                params.model,
                data['img'],
                trans_info=data['trans_info'],
                is_slide=False,
                stride=False,
                crop_size=None,
                use_multilabel=False)
            fname_out = os.path.join(output_dir, os.path.splitext(os.path.basename(fname))[0] + ".png")
            fname_out = prediction_to_file(pred, params.prediction_format, fname_out,
                                           mask_nth=params.mask_nth, classes=params.classes)
            result.append(fname_out)
    except KeyboardInterrupt:
        poller.keyboard_interrupt()
    except Exception:
        # Report and continue with the next file; SystemExit etc. are no longer swallowed.
        poller.error("Failed to process image: %s\n%s" % (fname, traceback.format_exc()))
    return result
def predict_on_images(input_dir, model, transforms, output_dir, tmp_dir, prediction_format="grayscale", labels=None,
                      mask_nth=1, poll_wait=1.0, continuous=False, use_watchdog=False, watchdog_check_interval=10.0,
                      delete_input=False, verbose=False, quiet=False):
    """
    Configures a Poller for the input directory and runs it, so that every
    image showing up there gets segmented with the supplied model.

    :param input_dir: directory that gets polled for new images
    :type input_dir: str
    :param model: the trained paddleseg model used for inference
    :param transforms: the transformations applied to each image before inference
    :param output_dir: directory receiving the predictions (and the input images, unless they get deleted)
    :type output_dir: str
    :param tmp_dir: directory for storing predictions until they are complete, None to write directly to output_dir
    :type tmp_dir: str
    :param prediction_format: the format of the generated prediction images (see PREDICTION_FORMATS)
    :type prediction_format: str
    :param labels: path to the labels file (one label per line, background included), turned into a dict via classes_dict
    :type labels: str
    :param mask_nth: only every nth row/column is used for contour tracing, speeding up large masks
    :type mask_nth: int
    :param poll_wait: seconds to wait between polls when not using the watchdog
    :type poll_wait: float
    :param continuous: whether to keep polling rather than stop once done
    :type continuous: bool
    :param use_watchdog: whether to react to file-creation events instead of polling at a fixed interval
    :type use_watchdog: bool
    :param watchdog_check_interval: seconds between the watchdog's checks for files it may have missed
    :type watchdog_check_interval: float
    :param delete_input: whether input images get deleted instead of moved to output_dir
    :type delete_input: bool
    :param verbose: whether to log additional information
    :type verbose: bool
    :param quiet: whether to suppress progress output
    :type quiet: bool
    """
    poller = Poller()

    # general polling behaviour, applied in a fixed order
    poller_settings = {
        "input_dir": input_dir,
        "output_dir": output_dir,
        "tmp_dir": tmp_dir,
        "extensions": SUPPORTED_EXTS,
        "delete_input": delete_input,
        "progress": not quiet,
        "verbose": verbose,
        "check_file": check_image,
        "process_file": process_image,
        "poll_wait": poll_wait,
        "continuous": continuous,
        "use_watchdog": use_watchdog,
        "watchdog_check_interval": watchdog_check_interval,
    }
    for setting_name, setting_value in poller_settings.items():
        setattr(poller, setting_name, setting_value)

    # job-specific values that the callbacks read back from poller.params
    job_params = (
        ("model", model),
        ("transforms", transforms),
        ("prediction_format", prediction_format),
        ("mask_nth", mask_nth),
    )
    for param_name, param_value in job_params:
        setattr(poller.params, param_name, param_value)
    poller.params.classes = classes_dict(labels)

    poller.poll()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="PaddleSeg - Prediction", prog="paddleseg_predict_poll", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--config', help='Path to the config file', required=True, default=None)
    parser.add_argument('--model_path', help='Path to the trained model (.pdparams file)', required=True, default=None)
    parser.add_argument('--device', help='The device to use', default="gpu:0")
    parser.add_argument('--labels', help='Path to the text file with the labels; one per line, including background', required=True, default=None)
    parser.add_argument('--prediction_in', help='Path to the test images', required=True, default=None)
    parser.add_argument('--prediction_out', help='Path to the output folder for the prediction images', required=True, default=None)
    parser.add_argument('--prediction_tmp', help='Path to the temporary folder for the prediction images', required=False, default=None)
    parser.add_argument('--prediction_format', default=PREDICTION_FORMAT_GRAYSCALE, choices=PREDICTION_FORMATS, help='The format for the prediction images')
    parser.add_argument('--mask_nth', type=int, help='To speed polygon detection up, use every nth row and column only (OPEX format only)', required=False, default=1)
    parser.add_argument('--poll_wait', type=float, help='poll interval in seconds when not using watchdog mode', required=False, default=1.0)
    parser.add_argument('--continuous', action='store_true', help='Whether to continuously load test images and perform prediction', required=False, default=False)
    parser.add_argument('--use_watchdog', action='store_true', help='Whether to react to file creation events rather than performing fixed-interval polling', required=False, default=False)
    parser.add_argument('--watchdog_check_interval', type=float, help='check interval in seconds for the watchdog', required=False, default=10.0)
    parser.add_argument('--delete_input', action='store_true', help='Whether to delete the input images rather than move them to --prediction_out directory', required=False, default=False)
    parser.add_argument('--verbose', action='store_true', help='Whether to output more logging info', required=False, default=False)
    parser.add_argument('--quiet', action='store_true', help='Whether to suppress output', required=False, default=False)
    parsed = parser.parse_args()
    try:
        # model/transforms are deliberately bound at module level (not inside a function)
        model, transforms = load_model(parsed.config, parsed.model_path, parsed.device)
        # Performing the prediction and producing the predictions files;
        # every parsed option is forwarded (mask_nth and poll_wait were previously dropped)
        predict_on_images(parsed.prediction_in, model, transforms, parsed.prediction_out, parsed.prediction_tmp,
                          prediction_format=parsed.prediction_format, labels=parsed.labels,
                          mask_nth=parsed.mask_nth, poll_wait=parsed.poll_wait,
                          continuous=parsed.continuous, use_watchdog=parsed.use_watchdog,
                          watchdog_check_interval=parsed.watchdog_check_interval,
                          delete_input=parsed.delete_input, verbose=parsed.verbose, quiet=parsed.quiet)
    except Exception:
        print(traceback.format_exc())