Skip to content

Commit 26cca22

Browse files
No public description
PiperOrigin-RevId: 764863901
1 parent 9e77955 commit 26cca22

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

official/projects/waste_identification_ml/Triton_TF_Cloud_Deployment/client/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from collections.abc import Mapping, Sequence
1818
import csv
19+
import dataclasses
1920
import logging
2021
import os
2122
from typing import Any, TypedDict
@@ -31,6 +32,20 @@ class ItemDict(TypedDict):
3132
supercategory: str
3233

3334

35+
@dataclasses.dataclass
36+
class BoundingBox:
37+
y1: int | float
38+
x1: int | float
39+
y2: int | float
40+
x2: int | float
41+
42+
43+
@dataclasses.dataclass
44+
class ImageSize:
45+
height: int
46+
width: int
47+
48+
3449
def _reframe_image_corners_relative_to_boxes(boxes: tf.Tensor) -> tf.Tensor:
3550
"""Reframe the image corners ([0, 0, 1, 1]) to be relative to boxes.
3651
@@ -455,3 +470,27 @@ def filter_detections(
455470
filtered_output['num_detections'] = np.array([new_num_detections])
456471

457472
return filtered_output
473+
474+
475+
def resize_bbox(
476+
bbox: BoundingBox, old_size: ImageSize, new_size: ImageSize
477+
) -> tuple[int, int, int, int]:
478+
"""Resize bounding box coordinates based on new image size.
479+
480+
Args:
481+
bbox: BoundingBox with original coordinates.
482+
old_size: Original image size.
483+
new_size: New image size.
484+
485+
Returns:
486+
Rescaled bounding box coordinates.
487+
"""
488+
scale_x = new_size.width / old_size.width
489+
scale_y = new_size.height / old_size.height
490+
491+
new_y1 = int(bbox.y1 * scale_y)
492+
new_x1 = int(bbox.x1 * scale_x)
493+
new_y2 = int(bbox.y2 * scale_y)
494+
new_x2 = int(bbox.x2 * scale_x)
495+
496+
return new_y1, new_x1, new_y2, new_x2

official/projects/waste_identification_ml/Triton_TF_Cloud_Deployment/client/utils_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import unittest
1818
import numpy as np
1919
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import utils
20+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client.utils import BoundingBox
21+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client.utils import ImageSize
2022

2123

2224
class TestLoadLabels(unittest.TestCase):
@@ -222,6 +224,26 @@ def test_extract_and_resize_single_object(self):
222224
# The output pixels in mask area should be non-zero
223225
self.assertTrue(np.any(obj > 0))
224226

227+
def test_resize_bbox_scaling(self):
228+
bbox = BoundingBox(y1=50, x1=100, y2=150, x2=200)
229+
old_size = ImageSize(height=200, width=400)
230+
new_size = ImageSize(height=400, width=800)
231+
232+
expected = (100, 200, 300, 400)
233+
result = utils.resize_bbox(bbox, old_size, new_size)
234+
235+
self.assertEqual(result, expected)
236+
237+
def test_resize_bbox_no_scaling(self):
238+
bbox = BoundingBox(y1=10, x1=20, y2=30, x2=40)
239+
old_size = ImageSize(height=100, width=100)
240+
new_size = ImageSize(height=100, width=100)
241+
242+
expected = (10, 20, 30, 40)
243+
result = utils.resize_bbox(bbox, old_size, new_size)
244+
245+
self.assertEqual(result, expected)
246+
225247

226248
if __name__ == '__main__':
227249
unittest.main()

0 commit comments

Comments
 (0)