Skip to content

Commit 9e77955

Browse files
No public description
PiperOrigin-RevId: 764604204
1 parent 677bcef commit 9e77955

File tree

1 file changed

+244
-0
lines changed

1 file changed

+244
-0
lines changed

official/projects/waste_identification_ml/Triton_TF_Cloud_Deployment/client/inference_pipeline.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,18 @@
1414

1515
"""Pipeline to run the prediction on the images folder with Triton server."""
1616

17+
import os
18+
import subprocess
19+
from absl import app
1720
from absl import flags
21+
import cv2
22+
import numpy as np
23+
from official.projects.waste_identification_ml.model_inference import color_and_property_extractor
24+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import feature_extraction
25+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import ffmpeg_ops
26+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import mask_bbox_saver
27+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import triton_server_inference
28+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import utils
1829

1930
INPUT_DIRECTORY = flags.DEFINE_string(
2031
"input_directory", None, "The path to the directory containing images."
@@ -71,3 +82,236 @@
7182
TABLE_ID = flags.DEFINE_string(
7283
"bq_table_id", "Circularnet_table", "BigQuery Table ID for features data"
7384
)
85+
86+
TRACKING_VISUALIZATION = flags.DEFINE_boolean(
87+
"tracking_visualization",
88+
False,
89+
"If True, visualize the tracking results.",
90+
)
91+
92+
CROPPED_OBJECTS = flags.DEFINE_boolean(
93+
"cropped_objects",
94+
False,
95+
"If True, save cropped objects per category from the output.",
96+
)
97+
98+
AREA_THRESHOLD = None
99+
HEIGHT_TRACKING = 300
100+
WIDTH_TRACKING = 300
101+
102+
103+
def main(_) -> None:
104+
# Check if the input and output directories are valid.
105+
if (
106+
not INPUT_DIRECTORY.value
107+
or not OUTPUT_DIRECTORY.value
108+
or not INPUT_DIRECTORY.value.startswith("gs://")
109+
or not OUTPUT_DIRECTORY.value.startswith("gs://")
110+
):
111+
raise ValueError("Bucket path must be non-empty starting with 'gs://'")
112+
113+
# Copy the images folder from GCP to the present directory.
114+
input_directory = (INPUT_DIRECTORY.value).rstrip("/\\")
115+
command = f"gsutil -m cp -r {input_directory} ."
116+
subprocess.run(command, shell=True, check=True)
117+
118+
# Create a folder to store the predictions.
119+
prediction_folder = os.path.basename(input_directory) + "_prediction"
120+
os.makedirs(prediction_folder, exist_ok=True)
121+
122+
if TRACKING_VISUALIZATION.value:
123+
# Create a folder to troubleshoot tracking results.
124+
tracking_folder = os.path.basename(input_directory) + "_tracking"
125+
os.makedirs(tracking_folder, exist_ok=True)
126+
127+
if CROPPED_OBJECTS.value:
128+
# Create a folder to save cropped objects per category from the output.
129+
cropped_obj_folder = os.path.basename(input_directory) + "_cropped_objects"
130+
os.makedirs(cropped_obj_folder, exist_ok=True)
131+
132+
# Create a log directory and a logger for logging.
133+
log_name = os.path.basename(INPUT_DIRECTORY.value)
134+
log_folder = os.path.join(os.getcwd(), "logs")
135+
os.makedirs(log_folder, exist_ok=True)
136+
logger = utils.create_log_file(log_name, log_folder)
137+
138+
# Read the labels which the model is trained on.
139+
labels_path = os.path.join(os.getcwd(), "labels.csv")
140+
labels, category_index = utils.load_labels(labels_path)
141+
142+
# Read files from a folder.
143+
files = utils.files_paths(os.path.basename(input_directory))
144+
145+
tracking_images = {}
146+
features_set = []
147+
image_plot = None
148+
149+
for frame, image_path in enumerate(files, start=1):
150+
# Prepare an input for a Triton model server from an image.
151+
logger.info(f"\nProcessing {os.path.basename(image_path)}")
152+
try:
153+
inputs, original_image, _ = (
154+
triton_server_inference.prepare_image(
155+
image_path, HEIGHT.value, WIDTH.value
156+
)
157+
)
158+
159+
# Extract the creation time of an image.
160+
creation_time = ffmpeg_ops.get_image_creation_time(image_path)
161+
logger.info(
162+
f"Successfully read an image for {os.path.basename(image_path)}"
163+
)
164+
except (cv2.error, ValueError, TypeError) as e:
165+
logger.info("Failed to read an image.")
166+
logger.exception("Exception occurred:", e)
167+
continue
168+
169+
try:
170+
model_name = MODEL.value
171+
result = triton_server_inference.infer(model_name, inputs)
172+
logger.info(
173+
f"Successfully got prediction for {os.path.basename(image_path)}"
174+
)
175+
logger.info(f"Total predictions:{result['num_detections'][0]}")
176+
except (KeyError, TypeError, RuntimeError, ValueError) as e:
177+
logger.info(
178+
f"Failed to get prediction for {os.path.basename(image_path)}"
179+
)
180+
logger.exception("Exception occurred:", e)
181+
continue
182+
183+
try:
184+
# Take predictions only above the threshold.
185+
if result["num_detections"][0]:
186+
scores = result["detection_scores"][0]
187+
filtered_indices = scores > PREDICTION_THRESHOLD.value
188+
189+
if any(filtered_indices):
190+
result = utils.filter_detections(result, filtered_indices)
191+
else:
192+
logger.info("Zero predictions after threshold.")
193+
continue
194+
except (KeyError, IndexError, TypeError, ValueError) as e:
195+
logger.info("Failed to filter out predictions.")
196+
logger.exception("Exception occured:", e)
197+
198+
try:
199+
# Convert bbox coordinates into normalized coordinates.
200+
if result["num_detections"][0]:
201+
result["normalized_boxes"] = result["detection_boxes"].copy()
202+
result["normalized_boxes"][:, :, [0, 2]] /= HEIGHT
203+
result["normalized_boxes"][:, :, [1, 3]] /= WIDTH
204+
result["detection_boxes"] = (
205+
result["detection_boxes"].round().astype(int)
206+
)
207+
208+
# Adjust the image size to ensure both dimensions are at least 1024
209+
# for saving images with bbox and masks.
210+
height_plot, width_plot = utils.adjust_image_size(
211+
original_image.shape[0], original_image.shape[1], 1024
212+
)
213+
214+
# Resize the original image to overlay bbox and masks on it.
215+
image_plot = cv2.resize(
216+
original_image,
217+
(width_plot, height_plot),
218+
interpolation=cv2.INTER_AREA,
219+
)
220+
221+
# Reframe the masks according to the new image size.
222+
result["detection_masks_reframed"] = utils.reframe_masks(
223+
result, "normalized_boxes", height_plot, width_plot
224+
)
225+
226+
# Filter the prediction results and remove the overlapping masks.
227+
unique_indices = utils.filter_masks(
228+
result["detection_masks_reframed"],
229+
iou_threshold=0.08,
230+
area_threshold=AREA_THRESHOLD,
231+
)
232+
result = utils.filter_detections(result, unique_indices)
233+
logger.info(
234+
f"Total predictions after processing: {result['num_detections'][0]}"
235+
)
236+
else:
237+
logger.info("Zero predictions after processing.")
238+
continue
239+
except (KeyError, IndexError, TypeError, ValueError) as e:
240+
logger.info("Issue in post processing predictions results.")
241+
logger.exception("Exception occured:", e)
242+
243+
try:
244+
if result["num_detections"][0]:
245+
result["detection_classes_names"] = np.array(
246+
[[str(labels[i - 1]) for i in result["detection_classes"][0]]]
247+
)
248+
249+
# Save the prediction results as an image file with bbx and masks.
250+
mask_bbox_saver.save_bbox_masks_labels(
251+
result=result,
252+
image=image_plot,
253+
file_name=os.path.basename(image_path),
254+
folder=prediction_folder,
255+
category_index=category_index,
256+
threshold=PREDICTION_THRESHOLD.value,
257+
)
258+
except (KeyError, IndexError, TypeError, ValueError) as e:
259+
logger.info("Issue in saving visualization of results.")
260+
logger.exception("Exception occured:", e)
261+
262+
try:
263+
# Resize an image for object tracking..
264+
tracking_image = cv2.resize(
265+
original_image,
266+
(WIDTH_TRACKING, HEIGHT_TRACKING),
267+
interpolation=cv2.INTER_AREA,
268+
)
269+
tracking_images[os.path.basename(image_path)] = tracking_image
270+
271+
# Reducing mask sizes in order to keep the memory required for object
272+
# tracking under a threshold.
273+
result["detection_masks_tracking"] = np.array([
274+
cv2.resize(
275+
i,
276+
(WIDTH_TRACKING, HEIGHT_TRACKING),
277+
interpolation=cv2.INTER_NEAREST,
278+
)
279+
for i in result["detection_masks_reframed"]
280+
])
281+
282+
# Crop objects from an image using masks for color detection.
283+
cropped_objects = [
284+
np.where(np.expand_dims(i, -1), image_plot, 0)
285+
for i in result["detection_masks_reframed"]
286+
]
287+
288+
# Perform color detection using clustering approach.
289+
dominant_colors = [
290+
*map(
291+
color_and_property_extractor.find_dominant_color, cropped_objects
292+
)
293+
]
294+
generic_color_names = color_and_property_extractor.get_generic_color_name(
295+
dominant_colors
296+
)
297+
298+
# Extract features.
299+
features = feature_extraction.extract_properties(
300+
tracking_image, result, "detection_masks_tracking"
301+
)
302+
features["source_name"] = os.path.basename(os.path.dirname(image_path))
303+
features["image_name"] = os.path.basename(image_path)
304+
features["creation_time"] = creation_time
305+
features["frame"] = frame
306+
features["detection_scores"] = result["detection_scores"][0]
307+
features["detection_classes"] = result["detection_classes"][0]
308+
features["detection_classes_names"] = result["detection_classes_names"][0]
309+
features["color"] = generic_color_names
310+
features_set.append(features)
311+
except (KeyError, IndexError, TypeError, ValueError) as e:
312+
logger.info("Failed to extract properties.")
313+
logger.exception("Exception occured:", e)
314+
315+
316+
if __name__ == "__main__":
317+
app.run(main)

0 commit comments

Comments
 (0)