"""Pipeline to run the prediction on the images folder with Triton server."""

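# Example invocation (a sketch only: the script filename and bucket paths are
# placeholders, and the output flag name is assumed from the OUTPUT_DIRECTORY
# flag referenced below):
#
#   python3 prediction_pipeline.py \
#       --input_directory=gs://my-bucket/images \
#       --output_directory=gs://my-bucket/predictions
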
import os
import subprocess
from absl import app
from absl import flags
import cv2
import numpy as np
from official.projects.waste_identification_ml.model_inference import color_and_property_extractor
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import feature_extraction
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import ffmpeg_ops
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import mask_bbox_saver
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import triton_server_inference
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import utils

INPUT_DIRECTORY = flags.DEFINE_string(
    "input_directory", None, "The path to the directory containing images."
)

# ... (additional flag definitions elided in this diff, including
# OUTPUT_DIRECTORY, HEIGHT, WIDTH, MODEL, and PREDICTION_THRESHOLD, all of
# which are referenced below) ...

TABLE_ID = flags.DEFINE_string(
    "bq_table_id", "Circularnet_table", "BigQuery Table ID for features data"
)

TRACKING_VISUALIZATION = flags.DEFINE_boolean(
    "tracking_visualization",
    False,
    "If True, visualize the tracking results.",
)

CROPPED_OBJECTS = flags.DEFINE_boolean(
    "cropped_objects",
    False,
    "If True, save cropped objects per category from the output.",
)

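# AREA_THRESHOLD is passed to utils.filter_masks as its area_threshold
# argument (None presumably disables area-based filtering). HEIGHT_TRACKING
# and WIDTH_TRACKING are the pixel dimensions used when resizing frames and
# masks for object tracking.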
AREA_THRESHOLD = None
HEIGHT_TRACKING = 300
WIDTH_TRACKING = 300


def main(_) -> None:
  # Check if the input and output directories are valid.
  if (
      not INPUT_DIRECTORY.value
      or not OUTPUT_DIRECTORY.value
      or not INPUT_DIRECTORY.value.startswith("gs://")
      or not OUTPUT_DIRECTORY.value.startswith("gs://")
  ):
    raise ValueError("Bucket paths must be non-empty and start with 'gs://'.")

  # Copy the images folder from GCS to the current working directory.
  input_directory = INPUT_DIRECTORY.value.rstrip("/\\")
  command = f"gsutil -m cp -r {input_directory} ."
  subprocess.run(command, shell=True, check=True)

  # Create a folder to store the predictions.
  prediction_folder = os.path.basename(input_directory) + "_prediction"
  os.makedirs(prediction_folder, exist_ok=True)

  if TRACKING_VISUALIZATION.value:
    # Create a folder to troubleshoot tracking results.
    tracking_folder = os.path.basename(input_directory) + "_tracking"
    os.makedirs(tracking_folder, exist_ok=True)

  if CROPPED_OBJECTS.value:
    # Create a folder to save cropped objects per category from the output.
    cropped_obj_folder = os.path.basename(input_directory) + "_cropped_objects"
    os.makedirs(cropped_obj_folder, exist_ok=True)

  # Create a log directory and a logger for logging. Use the stripped
  # input_directory so a trailing slash cannot yield an empty log name.
  log_name = os.path.basename(input_directory)
  log_folder = os.path.join(os.getcwd(), "logs")
  os.makedirs(log_folder, exist_ok=True)
  logger = utils.create_log_file(log_name, log_folder)

  # Read the labels on which the model was trained.
  labels_path = os.path.join(os.getcwd(), "labels.csv")
  labels, category_index = utils.load_labels(labels_path)

  # List the image file paths in the copied folder.
  files = utils.files_paths(os.path.basename(input_directory))

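  # Accumulators: resized frames keyed by image name (for tracking), one
  # feature table per processed image, and the most recent visualization
  # image.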
  tracking_images = {}
  features_set = []
  image_plot = None

  for frame, image_path in enumerate(files, start=1):
    # Prepare an input for a Triton model server from an image.
    logger.info(f"\nProcessing {os.path.basename(image_path)}")
    try:
      inputs, original_image, _ = triton_server_inference.prepare_image(
          image_path, HEIGHT.value, WIDTH.value
      )

      # Extract the creation time of an image.
      creation_time = ffmpeg_ops.get_image_creation_time(image_path)
      logger.info(
          f"Successfully read an image for {os.path.basename(image_path)}"
      )
    except (cv2.error, ValueError, TypeError) as e:
      logger.info("Failed to read an image.")
      logger.exception("Exception occurred: %s", e)
      continue

    try:
      model_name = MODEL.value
      result = triton_server_inference.infer(model_name, inputs)
      logger.info(
          f"Successfully got prediction for {os.path.basename(image_path)}"
      )
      logger.info(f"Total predictions: {result['num_detections'][0]}")
    except (KeyError, TypeError, RuntimeError, ValueError) as e:
      logger.info(
          f"Failed to get prediction for {os.path.basename(image_path)}"
      )
      logger.exception("Exception occurred: %s", e)
      continue

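    # The response dict is indexed below with "num_detections",
    # "detection_scores", "detection_boxes", and "detection_classes", matching
    # the TF Object Detection API output signature (assumed from this file's
    # usage, not verified against the deployed model).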
    try:
      # Keep only the predictions whose scores exceed the threshold.
      if result["num_detections"][0]:
        scores = result["detection_scores"][0]
        filtered_indices = scores > PREDICTION_THRESHOLD.value

        if any(filtered_indices):
          result = utils.filter_detections(result, filtered_indices)
        else:
          logger.info("Zero predictions after threshold.")
          continue
    except (KeyError, IndexError, TypeError, ValueError) as e:
      logger.info("Failed to filter out predictions.")
      logger.exception("Exception occurred: %s", e)

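    # Assumption: detection_boxes are [ymin, xmin, ymax, xmax] in pixels of
    # the model input resolution (HEIGHT x WIDTH), so dividing the y
    # coordinates by HEIGHT and the x coordinates by WIDTH below yields boxes
    # normalized to [0, 1].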
    try:
      # Convert bbox coordinates into normalized coordinates.
      if result["num_detections"][0]:
        result["normalized_boxes"] = result["detection_boxes"].copy()
        result["normalized_boxes"][:, :, [0, 2]] /= HEIGHT.value
        result["normalized_boxes"][:, :, [1, 3]] /= WIDTH.value
        result["detection_boxes"] = (
            result["detection_boxes"].round().astype(int)
        )

        # Adjust the image size to ensure both dimensions are at least 1024
        # for saving images with bbox and masks.
        height_plot, width_plot = utils.adjust_image_size(
            original_image.shape[0], original_image.shape[1], 1024
        )

        # Resize the original image to overlay bbox and masks on it.
        image_plot = cv2.resize(
            original_image,
            (width_plot, height_plot),
            interpolation=cv2.INTER_AREA,
        )

        # Reframe the masks according to the new image size.
        result["detection_masks_reframed"] = utils.reframe_masks(
            result, "normalized_boxes", height_plot, width_plot
        )

        # Filter the prediction results and remove the overlapping masks.
        unique_indices = utils.filter_masks(
            result["detection_masks_reframed"],
            iou_threshold=0.08,
            area_threshold=AREA_THRESHOLD,
        )
        result = utils.filter_detections(result, unique_indices)
        logger.info(
            f"Total predictions after processing: {result['num_detections'][0]}"
        )
      else:
        logger.info("Zero predictions after processing.")
        continue
    except (KeyError, IndexError, TypeError, ValueError) as e:
      logger.info("Issue in post-processing prediction results.")
      logger.exception("Exception occurred: %s", e)

    try:
      if result["num_detections"][0]:
        result["detection_classes_names"] = np.array(
            [[str(labels[i - 1]) for i in result["detection_classes"][0]]]
        )

        # Save the prediction results as an image file with bboxes and masks.
        mask_bbox_saver.save_bbox_masks_labels(
            result=result,
            image=image_plot,
            file_name=os.path.basename(image_path),
            folder=prediction_folder,
            category_index=category_index,
            threshold=PREDICTION_THRESHOLD.value,
        )
    except (KeyError, IndexError, TypeError, ValueError) as e:
      logger.info("Issue in saving visualization of results.")
      logger.exception("Exception occurred: %s", e)

    try:
      # Resize the image for object tracking.
      tracking_image = cv2.resize(
          original_image,
          (WIDTH_TRACKING, HEIGHT_TRACKING),
          interpolation=cv2.INTER_AREA,
      )
      tracking_images[os.path.basename(image_path)] = tracking_image

      # Reduce mask sizes to keep the memory required for object tracking
      # under a threshold.
      result["detection_masks_tracking"] = np.array([
          cv2.resize(
              mask,
              (WIDTH_TRACKING, HEIGHT_TRACKING),
              interpolation=cv2.INTER_NEAREST,
          )
          for mask in result["detection_masks_reframed"]
      ])

      # Crop objects from the image using masks for color detection.
      cropped_objects = [
          np.where(np.expand_dims(mask, -1), image_plot, 0)
          for mask in result["detection_masks_reframed"]
      ]

      # Perform color detection using a clustering approach.
      dominant_colors = list(
          map(color_and_property_extractor.find_dominant_color, cropped_objects)
      )
      generic_color_names = color_and_property_extractor.get_generic_color_name(
          dominant_colors
      )

      # Extract features.
      features = feature_extraction.extract_properties(
          tracking_image, result, "detection_masks_tracking"
      )
      features["source_name"] = os.path.basename(os.path.dirname(image_path))
      features["image_name"] = os.path.basename(image_path)
      features["creation_time"] = creation_time
      features["frame"] = frame
      features["detection_scores"] = result["detection_scores"][0]
      features["detection_classes"] = result["detection_classes"][0]
      features["detection_classes_names"] = result["detection_classes_names"][0]
      features["color"] = generic_color_names
      features_set.append(features)
    except (KeyError, IndexError, TypeError, ValueError) as e:
      logger.info("Failed to extract properties.")
      logger.exception("Exception occurred: %s", e)


if __name__ == "__main__":
  app.run(main)